mirror of
https://github.com/twofas/2fas-server.git
synced 2024-12-05 00:29:56 +01:00
feat: connect pass with kms (#29)
This commit is contained in:
parent
c00c8a4d5b
commit
782e77173d
4
.env
4
.env
@ -15,3 +15,7 @@ SECURITY_RATE_LIMIT_BE=100
|
|||||||
SECURITY_RATE_LIMIT_MOBILE=100
|
SECURITY_RATE_LIMIT_MOBILE=100
|
||||||
|
|
||||||
PASS_ADDR=:8082
|
PASS_ADDR=:8082
|
||||||
|
|
||||||
|
AWS_ACCESS_KEY_ID=test
|
||||||
|
AWS_SECRET_ACCESS_KEY=test
|
||||||
|
AWS_ENDPOINT="http://localhost:4566"
|
||||||
|
1
Makefile
1
Makefile
@ -32,6 +32,7 @@ tests-e2e: ## run end to end tests
|
|||||||
go test ./tests/mobile/... -count=1
|
go test ./tests/mobile/... -count=1
|
||||||
go test ./tests/support/... -count=1
|
go test ./tests/support/... -count=1
|
||||||
go test ./tests/system/... -count=1
|
go test ./tests/system/... -count=1
|
||||||
|
go test ./tests/pass/... -count=1
|
||||||
|
|
||||||
|
|
||||||
vendor-licenses: ## report vendor licenses
|
vendor-licenses: ## report vendor licenses
|
||||||
|
@ -17,7 +17,7 @@ func main() {
|
|||||||
logging.Fatal(err.Error())
|
logging.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
server := pass.NewServer(cfg.Addr)
|
server := pass.NewServer(cfg)
|
||||||
|
|
||||||
if err := server.Run(); err != nil {
|
if err := server.Run(); err != nil {
|
||||||
logging.Fatal(err.Error())
|
logging.Fatal(err.Error())
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
type PassConfig struct {
|
type PassConfig struct {
|
||||||
Addr string `envconfig:"PASS_ADDR" default:":8082"`
|
Addr string `envconfig:"PASS_ADDR" default:":8082"`
|
||||||
|
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
|
||||||
|
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
|
||||||
}
|
}
|
||||||
|
@ -90,8 +90,14 @@ services:
|
|||||||
- '1000'
|
- '1000'
|
||||||
ports:
|
ports:
|
||||||
- "8084:8082"
|
- "8084:8082"
|
||||||
|
environment:
|
||||||
|
# overwrite AWS_ENDPOINT from .env file. One in env is used to running app from local also.
|
||||||
|
AWS_ENDPOINT: http://localstack-main:4566
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
depends_on:
|
||||||
|
localstack:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
localstack:
|
localstack:
|
||||||
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
|
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
|
||||||
@ -100,6 +106,12 @@ services:
|
|||||||
- "127.0.0.1:4566:4566"
|
- "127.0.0.1:4566:4566"
|
||||||
environment:
|
environment:
|
||||||
- DEBUG=1
|
- DEBUG=1
|
||||||
|
healthcheck:
|
||||||
|
test: >-
|
||||||
|
curl -s localhost:4566/_localstack/health | grep -q '"kms": "running"'
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
volumes:
|
volumes:
|
||||||
- "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook
|
- "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook
|
||||||
- "./data/localstack:/var/lib/localstack"
|
- "./data/localstack:/var/lib/localstack"
|
||||||
|
@ -3,12 +3,17 @@ package pairing
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||||
)
|
)
|
||||||
|
|
||||||
// VerifyPairingToken verifies pairing token and returns extension_id
|
// VerifyPairingToken verifies pairing token and returns extension_id
|
||||||
func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) {
|
func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) {
|
||||||
// TODO verify pairing token and take extension from token, this is for debug only.
|
extensionID, err := p.signSvc.CanI(pairingToken, sign.ConnectionTypeBrowserExtensionWait)
|
||||||
extensionID := pairingToken
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||||
|
}
|
||||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("extension is not configured")
|
return "", errors.New("extension is not configured")
|
||||||
@ -16,10 +21,25 @@ func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (
|
|||||||
return extensionID, nil
|
return extensionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyProxyToken verifies proxy token and returns extension_id
|
// VerifyExtProxyToken verifies proxy token and returns extension_id
|
||||||
func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (string, error) {
|
func (p *Pairing) VerifyExtProxyToken(ctx context.Context, proxyToken string) (string, error) {
|
||||||
// TODO verify proxy token and take extension from token, this is for debug only.
|
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionProxy)
|
||||||
extensionID := proxyToken
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||||
|
}
|
||||||
|
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("extension is not configured")
|
||||||
|
}
|
||||||
|
return extensionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyMobileProxyToken verifies mobile token and returns extension_id
|
||||||
|
func (p *Pairing) VerifyMobileProxyToken(ctx context.Context, proxyToken string) (string, error) {
|
||||||
|
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileProxy)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||||
|
}
|
||||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("extension is not configured")
|
return "", errors.New("extension is not configured")
|
||||||
@ -29,8 +49,10 @@ func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (stri
|
|||||||
|
|
||||||
// VerifyConnectionToken verifies connection token and returns extension_id
|
// VerifyConnectionToken verifies connection token and returns extension_id
|
||||||
func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) {
|
func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) {
|
||||||
// TODO verify proxy token and take extension from token, this is for debug only.
|
extensionID, err := p.signSvc.CanI(connectionToken, sign.ConnectionTypeMobileConfirm)
|
||||||
extensionID := connectionToken
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||||
|
}
|
||||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("extension is not configured")
|
return "", errors.New("extension is not configured")
|
||||||
|
@ -61,7 +61,7 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu
|
|||||||
gCtx.Status(http.StatusForbidden)
|
gCtx.Status(http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
extensionID, err := pairingApp.VerifyProxyToken(gCtx, token)
|
extensionID, err := pairingApp.VerifyExtProxyToken(gCtx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to verify proxy token: %v", err)
|
logging.Errorf("Failed to verify proxy token: %v", err)
|
||||||
gCtx.Status(http.StatusInternalServerError)
|
gCtx.Status(http.StatusInternalServerError)
|
||||||
@ -106,11 +106,13 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pairingApp.ConfirmPairing(gCtx, req, extensionID); err != nil {
|
resp, err := pairingApp.ConfirmPairing(gCtx, req, extensionID)
|
||||||
|
if err != nil {
|
||||||
logging.Errorf("Failed to ConfirmPairing: %v", err)
|
logging.Errorf("Failed to ConfirmPairing: %v", err)
|
||||||
gCtx.Status(http.StatusInternalServerError)
|
gCtx.Status(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
gCtx.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +124,7 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc
|
|||||||
gCtx.Status(http.StatusForbidden)
|
gCtx.Status(http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token)
|
extensionID, err := pairingApp.VerifyMobileProxyToken(gCtx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to verify connection token: %v", err)
|
logging.Errorf("Failed to verify connection token: %v", err)
|
||||||
gCtx.Status(http.StatusInternalServerError)
|
gCtx.Status(http.StatusInternalServerError)
|
||||||
|
@ -9,10 +9,12 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/twofas/2fas-server/internal/common/logging"
|
"github.com/twofas/2fas-server/internal/common/logging"
|
||||||
|
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Pairing struct {
|
type Pairing struct {
|
||||||
store store
|
store store
|
||||||
|
signSvc *sign.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
type store interface {
|
type store interface {
|
||||||
@ -22,12 +24,17 @@ type store interface {
|
|||||||
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
|
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPairingApp() *Pairing {
|
func NewPairingApp(signService *sign.Service) *Pairing {
|
||||||
return &Pairing{
|
return &Pairing{
|
||||||
store: NewMemoryStore(),
|
store: NewMemoryStore(),
|
||||||
|
signSvc: signService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
pairingTokenValidityDuration = 3 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
type ConfigureBrowserExtensionRequest struct {
|
type ConfigureBrowserExtensionRequest struct {
|
||||||
ExtensionID string `json:"extension_id"`
|
ExtensionID string `json:"extension_id"`
|
||||||
}
|
}
|
||||||
@ -39,12 +46,26 @@ type ConfigureBrowserExtensionResponse struct {
|
|||||||
|
|
||||||
func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) {
|
func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) {
|
||||||
p.store.AddExtension(ctx, req.ExtensionID)
|
p.store.AddExtension(ctx, req.ExtensionID)
|
||||||
// TODO: generate connection token and pairing token.
|
|
||||||
connectionToken := req.ExtensionID
|
|
||||||
pairingToken := req.ExtensionID
|
|
||||||
|
|
||||||
|
pairingToken, err := p.signSvc.SignAndEncode(sign.Message{
|
||||||
|
ConnectionID: req.ExtensionID,
|
||||||
|
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
|
||||||
|
ConnectionType: sign.ConnectionTypeBrowserExtensionWait,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to generate pairing token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mobileToken, err := p.signSvc.SignAndEncode(sign.Message{
|
||||||
|
ConnectionID: req.ExtensionID,
|
||||||
|
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
|
||||||
|
ConnectionType: sign.ConnectionTypeMobileConfirm,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("Failed to generate mobile confirm token: %v", err)
|
||||||
|
}
|
||||||
return ConfigureBrowserExtensionResponse{
|
return ConfigureBrowserExtensionResponse{
|
||||||
ConnectionToken: connectionToken,
|
ConnectionToken: mobileToken,
|
||||||
BrowserExtensionPairingToken: pairingToken,
|
BrowserExtensionPairingToken: pairingToken,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -119,10 +140,17 @@ func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
|
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
|
||||||
// generate token here
|
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
|
||||||
|
ConnectionID: extID,
|
||||||
|
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
|
||||||
|
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to generate ext proxy token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := conn.WriteJSON(WaitForConnectionResponse{
|
if err := conn.WriteJSON(WaitForConnectionResponse{
|
||||||
// TODO: replace with real token.
|
BrowserExtensionProxyToken: extProxyToken,
|
||||||
BrowserExtensionProxyToken: extID,
|
|
||||||
Status: "ok",
|
Status: "ok",
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@ -141,12 +169,27 @@ type ConfirmPairingRequest struct {
|
|||||||
DeviceID string `json:"device_id"`
|
DeviceID string `json:"device_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) error {
|
type ConfirmPairingResponse struct {
|
||||||
return p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
|
ProxyToken string `json:"proxy_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) (ConfirmPairingResponse, error) {
|
||||||
|
mobileProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
|
||||||
|
ConnectionID: extensionID,
|
||||||
|
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
|
||||||
|
ConnectionType: sign.ConnectionTypeMobileProxy,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return ConfirmPairingResponse{}, fmt.Errorf("Failed to generate ext proxy token: %v", err)
|
||||||
|
}
|
||||||
|
if err := p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
|
||||||
Device: MobileDevice{
|
Device: MobileDevice{
|
||||||
DeviceID: req.DeviceID,
|
DeviceID: req.DeviceID,
|
||||||
FCMToken: req.FCMToken,
|
FCMToken: req.FCMToken,
|
||||||
},
|
},
|
||||||
PairedAt: time.Now().UTC(),
|
PairedAt: time.Now().UTC(),
|
||||||
})
|
}); err != nil {
|
||||||
|
return ConfirmPairingResponse{}, err
|
||||||
|
}
|
||||||
|
return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,15 @@
|
|||||||
package pass
|
package pass
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
"github.com/aws/aws-sdk-go/service/kms"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||||
|
|
||||||
|
"github.com/twofas/2fas-server/config"
|
||||||
httphelpers "github.com/twofas/2fas-server/internal/common/http"
|
httphelpers "github.com/twofas/2fas-server/internal/common/http"
|
||||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||||
"github.com/twofas/2fas-server/internal/pass/pairing"
|
"github.com/twofas/2fas-server/internal/pass/pairing"
|
||||||
@ -12,8 +20,27 @@ type Server struct {
|
|||||||
addr string
|
addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(addr string) *Server {
|
func NewServer(cfg config.PassConfig) *Server {
|
||||||
pairingApp := pairing.NewPairingApp()
|
var awsEndpoint *string
|
||||||
|
if cfg.AWSEndpoint != "" {
|
||||||
|
awsEndpoint = aws.String(cfg.AWSEndpoint)
|
||||||
|
}
|
||||||
|
sess, err := session.NewSession(&aws.Config{
|
||||||
|
Region: aws.String("us-east-1"),
|
||||||
|
S3ForcePathStyle: aws.Bool(true),
|
||||||
|
Endpoint: awsEndpoint,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
kmsClient := kms.New(sess)
|
||||||
|
|
||||||
|
signSvc, err := sign.NewService(cfg.KMSKeyID, kmsClient)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pairingApp := pairing.NewPairingApp(signSvc)
|
||||||
proxyApp := pairing.NewProxy()
|
proxyApp := pairing.NewProxy()
|
||||||
|
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
@ -35,7 +62,7 @@ func NewServer(addr string) *Server {
|
|||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
router: router,
|
router: router,
|
||||||
addr: addr,
|
addr: cfg.Addr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,4 +59,5 @@ const (
|
|||||||
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
|
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
|
||||||
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
|
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
|
||||||
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
|
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
|
||||||
|
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
|
||||||
)
|
)
|
||||||
|
@ -46,7 +46,8 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy); err != nil {
|
_, err = srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
|
||||||
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -156,7 +157,7 @@ func TestSignAndVerify(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
token := tc.tokenFn()
|
token := tc.tokenFn()
|
||||||
err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
|
_, err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
||||||
}
|
}
|
||||||
|
@ -3,15 +3,22 @@ package sign
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrInvalidClaims = errors.New("invalid claims")
|
var ErrInvalidClaims = errors.New("invalid claims")
|
||||||
|
|
||||||
|
type customClaims struct {
|
||||||
|
ConnectionID string `json:"c_id"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
// CanI establish connection with type tp given claims in token.
|
// CanI establish connection with type tp given claims in token.
|
||||||
func (s Service) CanI(tokenString string, ct ConnectionType) error {
|
// Returns extension_id from claims if token is valid for given type.
|
||||||
cl := jwt.MapClaims{}
|
func (s Service) CanI(tokenString string, ct ConnectionType) (string, error) {
|
||||||
|
cl := customClaims{}
|
||||||
|
|
||||||
// In Sign we removed `jwtHeader` from JWT before returning it.
|
// In Sign we removed `jwtHeader` from JWT before returning it.
|
||||||
// We need to add it again before doing the verification.
|
// We need to add it again before doing the verification.
|
||||||
@ -27,19 +34,19 @@ func (s Service) CanI(tokenString string, ct ConnectionType) error {
|
|||||||
jwt.WithExpirationRequired(),
|
jwt.WithExpirationRequired(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse token: %w", err)
|
return "", fmt.Errorf("failed to parse token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := token.Claims.GetAudience()
|
audClaims, err := token.Claims.GetAudience()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get claims: %w", err)
|
return "", fmt.Errorf("failed to get claims: %w", err)
|
||||||
}
|
}
|
||||||
|
if !slices.Contains(audClaims, string(ct)) {
|
||||||
for _, aud := range claims {
|
return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
|
||||||
if aud == string(ct) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if cl.ConnectionID == "" {
|
||||||
return fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
|
return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, "c_id")
|
||||||
|
}
|
||||||
|
// TODO: rename connectionID to extensionID.
|
||||||
|
return cl.ConnectionID, nil
|
||||||
}
|
}
|
||||||
|
@ -43,9 +43,11 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
|
|||||||
t.Log(token)
|
t.Log(token)
|
||||||
t.Log("Length of the token is", len(token))
|
t.Log("Length of the token is", len(token))
|
||||||
|
|
||||||
if err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy); err != nil {
|
extensionID, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
|
||||||
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
t.Log(extensionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSignAndVerify(t *testing.T) {
|
func TestSignAndVerify(t *testing.T) {
|
||||||
@ -139,7 +141,7 @@ func TestSignAndVerify(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
token := tc.tokenFn()
|
token := tc.tokenFn()
|
||||||
err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
|
_, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,7 +25,13 @@ var (
|
|||||||
wsDialer = websocket.DefaultDialer
|
wsDialer = websocket.DefaultDialer
|
||||||
)
|
)
|
||||||
|
|
||||||
const api = "localhost:8082"
|
func getAPIURL() string {
|
||||||
|
addr := os.Getenv("PASS_ADDR")
|
||||||
|
if addr != "" {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
return "localhost:8082"
|
||||||
|
}
|
||||||
|
|
||||||
func TestPassHappyFlow(t *testing.T) {
|
func TestPassHappyFlow(t *testing.T) {
|
||||||
resp, err := configureBrowserExtension()
|
resp, err := configureBrowserExtension()
|
||||||
@ -38,15 +45,15 @@ func TestPassHappyFlow(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(browserExtensionDone)
|
defer close(browserExtensionDone)
|
||||||
|
|
||||||
err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
|
extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
|
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = proxyWebSocket(
|
err = proxyWebSocket(
|
||||||
"ws://"+api+"/browser_extension/proxy_to_mobile",
|
"ws://"+getAPIURL()+"/browser_extension/proxy_to_mobile",
|
||||||
resp.BrowserExtensionPairingToken,
|
extProxyToken,
|
||||||
"sent from browser extension",
|
"sent from browser extension",
|
||||||
"sent from mobile")
|
"sent from mobile")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -58,15 +65,15 @@ func TestPassHappyFlow(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer close(mobileDone)
|
defer close(mobileDone)
|
||||||
|
|
||||||
err := confirmMobile(resp.ConnectionToken)
|
mobileProxyToken, err := confirmMobile(resp.ConnectionToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Mobile: confirm failed: %v", err)
|
t.Errorf("Mobile: confirm failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = proxyWebSocket(
|
err = proxyWebSocket(
|
||||||
"ws://"+api+"/mobile/proxy_to_browser_extension",
|
"ws://"+getAPIURL()+"/mobile/proxy_to_browser_extension",
|
||||||
resp.BrowserExtensionPairingToken,
|
mobileProxyToken,
|
||||||
"sent from mobile",
|
"sent from mobile",
|
||||||
"sent from browser extension",
|
"sent from browser extension",
|
||||||
)
|
)
|
||||||
@ -75,41 +82,42 @@ func TestPassHappyFlow(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
<-browserExtensionDone
|
<-browserExtensionDone
|
||||||
<-mobileDone
|
<-mobileDone
|
||||||
}
|
}
|
||||||
|
|
||||||
func browserExtensionWaitForConfirm(token string) error {
|
func browserExtensionWaitForConfirm(token string) (string, error) {
|
||||||
url := "ws://" + api + "/browser_extension/wait_for_connection"
|
url := "ws://" + getAPIURL() + "/browser_extension/wait_for_connection"
|
||||||
|
|
||||||
var resp struct {
|
var resp struct {
|
||||||
Status string `json:"status"`
|
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
DeviceID string `json:"device_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := dialWS(url, token)
|
conn, err := dialWS(url, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
conn.SetReadDeadline(time.Now().Add(time.Second))
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
_, message, err := conn.ReadMessage()
|
_, message, err := conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error reading from connection: %w", err)
|
return "", fmt.Errorf("error reading from connection: %w", err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(message, &resp); err != nil {
|
if err := json.Unmarshal(message, &resp); err != nil {
|
||||||
return fmt.Errorf("failed to decode message: %w", err)
|
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||||
}
|
}
|
||||||
const expectedStatus = "ok"
|
const expectedStatus = "ok"
|
||||||
if resp.Status != expectedStatus {
|
if resp.Status != expectedStatus {
|
||||||
return fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
|
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
|
||||||
}
|
}
|
||||||
return nil
|
return resp.BrowserExtensionProxyToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
|
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
|
||||||
url := "http://" + api + "/browser_extension/configure"
|
url := "http://" + getAPIURL() + "/browser_extension/configure"
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String()))
|
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -138,26 +146,39 @@ func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func confirmMobile(connectionToken string) error {
|
// confirmMobile confirms pairing and returns mobile proxy token.
|
||||||
url := "http://" + api + "/mobile/confirm"
|
func confirmMobile(connectionToken string) (string, error) {
|
||||||
|
url := "http://" + getAPIURL() + "/mobile/confirm"
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String()))
|
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to prepare the reqest: %w", err)
|
return "", fmt.Errorf("failed to prepare the reqest: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
|
||||||
|
|
||||||
httpResp, err := httpClient.Do(req)
|
httpResp, err := httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to perform the reqest: %w", err)
|
return "", fmt.Errorf("failed to perform the reqest: %w", err)
|
||||||
}
|
}
|
||||||
defer httpResp.Body.Close()
|
defer httpResp.Body.Close()
|
||||||
|
|
||||||
if httpResp.StatusCode > 299 {
|
bb, err := io.ReadAll(httpResp.Body)
|
||||||
return fmt.Errorf("unexpected response: %s", httpResp.Status)
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read body from response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if httpResp.StatusCode >= 300 {
|
||||||
|
return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct {
|
||||||
|
ProxyToken string `json:"proxy_token"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(bb, &resp); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to decode the response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.ProxyToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
|
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
|
||||||
@ -165,7 +186,7 @@ func confirmMobile(connectionToken string) error {
|
|||||||
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
|
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
|
||||||
conn, err := dialWS(url, token)
|
conn, err := dialWS(url, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
@ -204,7 +225,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial ws %q: %v", url, err)
|
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user