diff --git a/.env b/.env index bf89297..ac39e60 100644 --- a/.env +++ b/.env @@ -15,3 +15,7 @@ SECURITY_RATE_LIMIT_BE=100 SECURITY_RATE_LIMIT_MOBILE=100 PASS_ADDR=:8082 + +AWS_ACCESS_KEY_ID=test +AWS_SECRET_ACCESS_KEY=test +AWS_ENDPOINT="http://localhost:4566" diff --git a/Makefile b/Makefile index 9c271fe..8442fe4 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,7 @@ tests-e2e: ## run end to end tests go test ./tests/mobile/... -count=1 go test ./tests/support/... -count=1 go test ./tests/system/... -count=1 + go test ./tests/pass/... -count=1 vendor-licenses: ## report vendor licenses diff --git a/cmd/pass/main.go b/cmd/pass/main.go index a5f79ee..3c770bd 100644 --- a/cmd/pass/main.go +++ b/cmd/pass/main.go @@ -17,7 +17,7 @@ func main() { logging.Fatal(err.Error()) } - server := pass.NewServer(cfg.Addr) + server := pass.NewServer(cfg) if err := server.Run(); err != nil { logging.Fatal(err.Error()) diff --git a/config/pass_config.go b/config/pass_config.go index a92ffae..37e645b 100644 --- a/config/pass_config.go +++ b/config/pass_config.go @@ -1,5 +1,7 @@ package config type PassConfig struct { - Addr string `envconfig:"PASS_ADDR" default:":8082"` + Addr string `envconfig:"PASS_ADDR" default:":8082"` + KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"` + AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""` } diff --git a/docker-compose.yml b/docker-compose.yml index 4a63fc6..7c4cf75 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -90,8 +90,14 @@ services: - '1000' ports: - "8084:8082" + environment: + # overwrite AWS_ENDPOINT from .env file. One in env is used to running app from local also. + AWS_ENDPOINT: http://localstack-main:4566 env_file: - .env + depends_on: + localstack: + condition: service_healthy localstack: container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}" @@ -100,6 +106,12 @@ services: - "127.0.0.1:4566:4566" environment: - DEBUG=1 + healthcheck: + test: >- + curl -s localhost:4566/_localstack/health | grep -q '"kms": "running"' + interval: 5s + timeout: 5s + retries: 5 volumes: - "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook - "./data/localstack:/var/lib/localstack" diff --git a/internal/pass/pairing/auth.go b/internal/pass/pairing/auth.go index dfcc007..0548f0a 100644 --- a/internal/pass/pairing/auth.go +++ b/internal/pass/pairing/auth.go @@ -3,12 +3,17 @@ package pairing import ( "context" "errors" + "fmt" + + "github.com/twofas/2fas-server/internal/pass/sign" ) // VerifyPairingToken verifies pairing token and returns extension_id func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) { - // TODO verify pairing token and take extension from token, this is for debug only. - extensionID := pairingToken + extensionID, err := p.signSvc.CanI(pairingToken, sign.ConnectionTypeBrowserExtensionWait) + if err != nil { + return "", fmt.Errorf("failed to check token signature: %w", err) + } ok := p.store.ExtensionExists(ctx, extensionID) if !ok { return "", errors.New("extension is not configured") @@ -16,10 +21,25 @@ func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) ( return extensionID, nil } -// VerifyProxyToken verifies proxy token and returns extension_id -func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (string, error) { - // TODO verify proxy token and take extension from token, this is for debug only. - extensionID := proxyToken +// VerifyExtProxyToken verifies proxy token and returns extension_id +func (p *Pairing) VerifyExtProxyToken(ctx context.Context, proxyToken string) (string, error) { + extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionProxy) + if err != nil { + return "", fmt.Errorf("failed to check token signature: %w", err) + } + ok := p.store.ExtensionExists(ctx, extensionID) + if !ok { + return "", errors.New("extension is not configured") + } + return extensionID, nil +} + +// VerifyMobileProxyToken verifies mobile token and returns extension_id +func (p *Pairing) VerifyMobileProxyToken(ctx context.Context, proxyToken string) (string, error) { + extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileProxy) + if err != nil { + return "", fmt.Errorf("failed to check token signature: %w", err) + } ok := p.store.ExtensionExists(ctx, extensionID) if !ok { return "", errors.New("extension is not configured") @@ -29,8 +49,10 @@ func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (stri // VerifyConnectionToken verifies connection token and returns extension_id func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) { - // TODO verify proxy token and take extension from token, this is for debug only. - extensionID := connectionToken + extensionID, err := p.signSvc.CanI(connectionToken, sign.ConnectionTypeMobileConfirm) + if err != nil { + return "", fmt.Errorf("failed to check token signature: %w", err) + } ok := p.store.ExtensionExists(ctx, extensionID) if !ok { return "", errors.New("extension is not configured") diff --git a/internal/pass/pairing/handlers.go b/internal/pass/pairing/handlers.go index 8322f8f..fff72b6 100644 --- a/internal/pass/pairing/handlers.go +++ b/internal/pass/pairing/handlers.go @@ -61,7 +61,7 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu gCtx.Status(http.StatusForbidden) return } - extensionID, err := pairingApp.VerifyProxyToken(gCtx, token) + extensionID, err := pairingApp.VerifyExtProxyToken(gCtx, token) if err != nil { logging.Errorf("Failed to verify proxy token: %v", err) gCtx.Status(http.StatusInternalServerError) @@ -106,11 +106,13 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc { return } - if err := pairingApp.ConfirmPairing(gCtx, req, extensionID); err != nil { + resp, err := pairingApp.ConfirmPairing(gCtx, req, extensionID) + if err != nil { logging.Errorf("Failed to ConfirmPairing: %v", err) gCtx.Status(http.StatusInternalServerError) return } + gCtx.JSON(http.StatusOK, resp) } } @@ -122,7 +124,7 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc gCtx.Status(http.StatusForbidden) return } - extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token) + extensionID, err := pairingApp.VerifyMobileProxyToken(gCtx, token) if err != nil { logging.Errorf("Failed to verify connection token: %v", err) gCtx.Status(http.StatusInternalServerError) diff --git a/internal/pass/pairing/pairing.go b/internal/pass/pairing/pairing.go index 3006f13..53df6f6 100644 --- a/internal/pass/pairing/pairing.go +++ b/internal/pass/pairing/pairing.go @@ -9,10 +9,12 @@ import ( "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/twofas/2fas-server/internal/common/logging" + "github.com/twofas/2fas-server/internal/pass/sign" ) type Pairing struct { - store store + store store + signSvc *sign.Service } type store interface { @@ -22,12 +24,17 @@ type store interface { SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error } -func NewPairingApp() *Pairing { +func NewPairingApp(signService *sign.Service) *Pairing { return &Pairing{ - store: NewMemoryStore(), + store: NewMemoryStore(), + signSvc: signService, } } +const ( + pairingTokenValidityDuration = 3 * time.Minute +) + type ConfigureBrowserExtensionRequest struct { ExtensionID string `json:"extension_id"` } @@ -39,12 +46,26 @@ type ConfigureBrowserExtensionResponse struct { func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) { p.store.AddExtension(ctx, req.ExtensionID) - // TODO: generate connection token and pairing token. - connectionToken := req.ExtensionID - pairingToken := req.ExtensionID + pairingToken, err := p.signSvc.SignAndEncode(sign.Message{ + ConnectionID: req.ExtensionID, + ExpiresAt: time.Now().Add(pairingTokenValidityDuration), + ConnectionType: sign.ConnectionTypeBrowserExtensionWait, + }) + if err != nil { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to generate pairing token: %v", err) + } + + mobileToken, err := p.signSvc.SignAndEncode(sign.Message{ + ConnectionID: req.ExtensionID, + ExpiresAt: time.Now().Add(pairingTokenValidityDuration), + ConnectionType: sign.ConnectionTypeMobileConfirm, + }) + if err != nil { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("Failed to generate mobile confirm token: %v", err) + } return ConfigureBrowserExtensionResponse{ - ConnectionToken: connectionToken, + ConnectionToken: mobileToken, BrowserExtensionPairingToken: pairingToken, }, nil } @@ -119,10 +140,17 @@ func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logr } func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error { - // generate token here + extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{ + ConnectionID: extID, + ExpiresAt: time.Now().Add(pairingTokenValidityDuration), + ConnectionType: sign.ConnectionTypeBrowserExtensionProxy, + }) + if err != nil { + return fmt.Errorf("Failed to generate ext proxy token: %v", err) + } + if err := conn.WriteJSON(WaitForConnectionResponse{ - // TODO: replace with real token. - BrowserExtensionProxyToken: extID, + BrowserExtensionProxyToken: extProxyToken, Status: "ok", DeviceID: deviceID, }); err != nil { @@ -141,12 +169,27 @@ type ConfirmPairingRequest struct { DeviceID string `json:"device_id"` } -func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) error { - return p.store.SetPairingInfo(ctx, extensionID, PairingInfo{ +type ConfirmPairingResponse struct { + ProxyToken string `json:"proxy_token"` +} + +func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) (ConfirmPairingResponse, error) { + mobileProxyToken, err := p.signSvc.SignAndEncode(sign.Message{ + ConnectionID: extensionID, + ExpiresAt: time.Now().Add(pairingTokenValidityDuration), + ConnectionType: sign.ConnectionTypeMobileProxy, + }) + if err != nil { + return ConfirmPairingResponse{}, fmt.Errorf("Failed to generate ext proxy token: %v", err) + } + if err := p.store.SetPairingInfo(ctx, extensionID, PairingInfo{ Device: MobileDevice{ DeviceID: req.DeviceID, FCMToken: req.FCMToken, }, PairedAt: time.Now().UTC(), - }) + }); err != nil { + return ConfirmPairingResponse{}, err + } + return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil } diff --git a/internal/pass/server.go b/internal/pass/server.go index 9018eb1..32db71d 100644 --- a/internal/pass/server.go +++ b/internal/pass/server.go @@ -1,7 +1,15 @@ package pass import ( + "log" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" "github.com/gin-gonic/gin" + "github.com/twofas/2fas-server/internal/pass/sign" + + "github.com/twofas/2fas-server/config" httphelpers "github.com/twofas/2fas-server/internal/common/http" "github.com/twofas/2fas-server/internal/common/recovery" "github.com/twofas/2fas-server/internal/pass/pairing" @@ -12,8 +20,27 @@ type Server struct { addr string } -func NewServer(addr string) *Server { - pairingApp := pairing.NewPairingApp() +func NewServer(cfg config.PassConfig) *Server { + var awsEndpoint *string + if cfg.AWSEndpoint != "" { + awsEndpoint = aws.String(cfg.AWSEndpoint) + } + sess, err := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + S3ForcePathStyle: aws.Bool(true), + Endpoint: awsEndpoint, + }) + if err != nil { + log.Fatal(err) + } + kmsClient := kms.New(sess) + + signSvc, err := sign.NewService(cfg.KMSKeyID, kmsClient) + if err != nil { + log.Fatal(err) + } + + pairingApp := pairing.NewPairingApp(signSvc) proxyApp := pairing.NewProxy() router := gin.New() @@ -35,7 +62,7 @@ func NewServer(addr string) *Server { return &Server{ router: router, - addr: addr, + addr: cfg.Addr, } } diff --git a/internal/pass/sign/lib.go b/internal/pass/sign/lib.go index 0ed639d..e1b7d3d 100644 --- a/internal/pass/sign/lib.go +++ b/internal/pass/sign/lib.go @@ -59,4 +59,5 @@ const ( ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait" ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy" ConnectionTypeMobileProxy ConnectionType = "mobile/proxy" + ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm" ) diff --git a/internal/pass/sign/lib_test.go b/internal/pass/sign/lib_test.go index 223f620..7eabe6f 100644 --- a/internal/pass/sign/lib_test.go +++ b/internal/pass/sign/lib_test.go @@ -46,7 +46,8 @@ func TestSignAndVerifyHappyPath(t *testing.T) { t.Fatal(err) } - if err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy); err != nil { + _, err = srv.CanI(token, ConnectionTypeBrowserExtensionProxy) + if err != nil { t.Fatal(err) } } @@ -156,7 +157,7 @@ func TestSignAndVerify(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { token := tc.tokenFn() - err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy) + _, err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy) if err == nil { t.Fatalf("Expected error %v, got nil", tc.expectedError) } diff --git a/internal/pass/sign/verify.go b/internal/pass/sign/verify.go index 518a211..dbefc36 100644 --- a/internal/pass/sign/verify.go +++ b/internal/pass/sign/verify.go @@ -3,15 +3,22 @@ package sign import ( "errors" "fmt" + "slices" "github.com/golang-jwt/jwt/v5" ) var ErrInvalidClaims = errors.New("invalid claims") +type customClaims struct { + ConnectionID string `json:"c_id"` + jwt.RegisteredClaims +} + // CanI establish connection with type tp given claims in token. -func (s Service) CanI(tokenString string, ct ConnectionType) error { - cl := jwt.MapClaims{} +// Returns extension_id from claims if token is valid for given type. +func (s Service) CanI(tokenString string, ct ConnectionType) (string, error) { + cl := customClaims{} // In Sign we removed `jwtHeader` from JWT before returning it. // We need to add it again before doing the verification. @@ -27,19 +34,19 @@ func (s Service) CanI(tokenString string, ct ConnectionType) error { jwt.WithExpirationRequired(), ) if err != nil { - return fmt.Errorf("failed to parse token: %w", err) + return "", fmt.Errorf("failed to parse token: %w", err) } - claims, err := token.Claims.GetAudience() + audClaims, err := token.Claims.GetAudience() if err != nil { - return fmt.Errorf("failed to get claims: %w", err) + return "", fmt.Errorf("failed to get claims: %w", err) } - - for _, aud := range claims { - if aud == string(ct) { - return nil - } + if !slices.Contains(audClaims, string(ct)) { + return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct) } - - return fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct) + if cl.ConnectionID == "" { + return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, "c_id") + } + // TODO: rename connectionID to extensionID. + return cl.ConnectionID, nil } diff --git a/tests/pass/kms_test.go b/tests/pass/kms_test.go index f11ee27..f00b856 100644 --- a/tests/pass/kms_test.go +++ b/tests/pass/kms_test.go @@ -43,9 +43,11 @@ func TestSignAndVerifyHappyPath(t *testing.T) { t.Log(token) t.Log("Length of the token is", len(token)) - if err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy); err != nil { + extensionID, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy) + if err != nil { t.Fatal(err) } + t.Log(extensionID) } func TestSignAndVerify(t *testing.T) { @@ -139,7 +141,7 @@ func TestSignAndVerify(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { token := tc.tokenFn() - err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy) + _, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy) if err == nil { t.Fatalf("Expected error %v, got nil", tc.expectedError) } diff --git a/tests/pass/pass_test.go b/tests/pass/pass_test.go index ad6ec60..b54320b 100644 --- a/tests/pass/pass_test.go +++ b/tests/pass/pass_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "os" "testing" "time" @@ -24,7 +25,13 @@ var ( wsDialer = websocket.DefaultDialer ) -const api = "localhost:8082" +func getAPIURL() string { + addr := os.Getenv("PASS_ADDR") + if addr != "" { + return addr + } + return "localhost:8082" +} func TestPassHappyFlow(t *testing.T) { resp, err := configureBrowserExtension() @@ -38,15 +45,15 @@ func TestPassHappyFlow(t *testing.T) { go func() { defer close(browserExtensionDone) - err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken) + extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken) if err != nil { t.Errorf("Error when Browser Extension waited for confirm: %v", err) return } err = proxyWebSocket( - "ws://"+api+"/browser_extension/proxy_to_mobile", - resp.BrowserExtensionPairingToken, + "ws://"+getAPIURL()+"/browser_extension/proxy_to_mobile", + extProxyToken, "sent from browser extension", "sent from mobile") if err != nil { @@ -58,15 +65,15 @@ func TestPassHappyFlow(t *testing.T) { go func() { defer close(mobileDone) - err := confirmMobile(resp.ConnectionToken) + mobileProxyToken, err := confirmMobile(resp.ConnectionToken) if err != nil { t.Errorf("Mobile: confirm failed: %v", err) return } err = proxyWebSocket( - "ws://"+api+"/mobile/proxy_to_browser_extension", - resp.BrowserExtensionPairingToken, + "ws://"+getAPIURL()+"/mobile/proxy_to_browser_extension", + mobileProxyToken, "sent from mobile", "sent from browser extension", ) @@ -75,41 +82,42 @@ func TestPassHappyFlow(t *testing.T) { return } }() - <-browserExtensionDone <-mobileDone } -func browserExtensionWaitForConfirm(token string) error { - url := "ws://" + api + "/browser_extension/wait_for_connection" +func browserExtensionWaitForConfirm(token string) (string, error) { + url := "ws://" + getAPIURL() + "/browser_extension/wait_for_connection" var resp struct { - Status string `json:"status"` + BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"` + Status string `json:"status"` + DeviceID string `json:"device_id"` } conn, err := dialWS(url, token) if err != nil { - return err + return "", err } defer conn.Close() - conn.SetReadDeadline(time.Now().Add(time.Second)) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _, message, err := conn.ReadMessage() if err != nil { - return fmt.Errorf("error reading from connection: %w", err) + return "", fmt.Errorf("error reading from connection: %w", err) } if err := json.Unmarshal(message, &resp); err != nil { - return fmt.Errorf("failed to decode message: %w", err) + return "", fmt.Errorf("failed to decode message: %w", err) } const expectedStatus = "ok" if resp.Status != expectedStatus { - return fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus) + return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus) } - return nil + return resp.BrowserExtensionProxyToken, nil } func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) { - url := "http://" + api + "/browser_extension/configure" + url := "http://" + getAPIURL() + "/browser_extension/configure" req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String())) if err != nil { @@ -138,26 +146,39 @@ func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) { return resp, nil } -func confirmMobile(connectionToken string) error { - url := "http://" + api + "/mobile/confirm" +// confirmMobile confirms pairing and returns mobile proxy token. +func confirmMobile(connectionToken string) (string, error) { + url := "http://" + getAPIURL() + "/mobile/confirm" req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String())) if err != nil { - return fmt.Errorf("failed to prepare the reqest: %w", err) + return "", fmt.Errorf("failed to prepare the reqest: %w", err) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken)) httpResp, err := httpClient.Do(req) if err != nil { - return fmt.Errorf("failed to perform the reqest: %w", err) + return "", fmt.Errorf("failed to perform the reqest: %w", err) } defer httpResp.Body.Close() - if httpResp.StatusCode > 299 { - return fmt.Errorf("unexpected response: %s", httpResp.Status) + bb, err := io.ReadAll(httpResp.Body) + if err != nil { + return "", fmt.Errorf("failed to read body from response: %w", err) } - return nil + if httpResp.StatusCode >= 300 { + return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb)) + } + + var resp struct { + ProxyToken string `json:"proxy_token"` + } + if err := json.Unmarshal(bb, &resp); err != nil { + return "", fmt.Errorf("failed to decode the response: %w", err) + } + + return resp.ProxyToken, nil } // proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and @@ -165,7 +186,7 @@ func confirmMobile(connectionToken string) error { func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error { conn, err := dialWS(url, token) if err != nil { - return nil + return err } defer conn.Close() @@ -204,7 +225,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) { }, }) if err != nil { - return nil, fmt.Errorf("failed to dial ws %q: %v", url, err) + return nil, fmt.Errorf("failed to dial ws %q: %w", url, err) } return conn, nil }