feat: connect pass with kms (#29)

This commit is contained in:
Tobiasz Heller 2024-01-24 20:57:31 +01:00 committed by GitHub
parent c00c8a4d5b
commit 782e77173d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 217 additions and 72 deletions

4
.env
View File

@ -15,3 +15,7 @@ SECURITY_RATE_LIMIT_BE=100
SECURITY_RATE_LIMIT_MOBILE=100
PASS_ADDR=:8082
AWS_ACCESS_KEY_ID=test
AWS_SECRET_ACCESS_KEY=test
AWS_ENDPOINT="http://localhost:4566"

View File

@ -32,6 +32,7 @@ tests-e2e: ## run end to end tests
go test ./tests/mobile/... -count=1
go test ./tests/support/... -count=1
go test ./tests/system/... -count=1
go test ./tests/pass/... -count=1
vendor-licenses: ## report vendor licenses

View File

@ -17,7 +17,7 @@ func main() {
logging.Fatal(err.Error())
}
server := pass.NewServer(cfg.Addr)
server := pass.NewServer(cfg)
if err := server.Run(); err != nil {
logging.Fatal(err.Error())

View File

@ -1,5 +1,7 @@
package config
type PassConfig struct {
Addr string `envconfig:"PASS_ADDR" default:":8082"`
Addr string `envconfig:"PASS_ADDR" default:":8082"`
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
}

View File

@ -90,8 +90,14 @@ services:
- '1000'
ports:
- "8084:8082"
environment:
# overwrite AWS_ENDPOINT from .env file. One in env is used to running app from local also.
AWS_ENDPOINT: http://localstack-main:4566
env_file:
- .env
depends_on:
localstack:
condition: service_healthy
localstack:
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
@ -100,6 +106,12 @@ services:
- "127.0.0.1:4566:4566"
environment:
- DEBUG=1
healthcheck:
test: >-
curl -s localhost:4566/_localstack/health | grep -q '"kms": "running"'
interval: 5s
timeout: 5s
retries: 5
volumes:
- "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook
- "./data/localstack:/var/lib/localstack"

View File

@ -3,12 +3,17 @@ package pairing
import (
"context"
"errors"
"fmt"
"github.com/twofas/2fas-server/internal/pass/sign"
)
// VerifyPairingToken verifies pairing token and returns extension_id
func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) {
// TODO verify pairing token and take extension from token, this is for debug only.
extensionID := pairingToken
extensionID, err := p.signSvc.CanI(pairingToken, sign.ConnectionTypeBrowserExtensionWait)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID)
if !ok {
return "", errors.New("extension is not configured")
@ -16,10 +21,25 @@ func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (
return extensionID, nil
}
// VerifyProxyToken verifies proxy token and returns extension_id
func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (string, error) {
// TODO verify proxy token and take extension from token, this is for debug only.
extensionID := proxyToken
// VerifyExtProxyToken verifies proxy token and returns extension_id
func (p *Pairing) VerifyExtProxyToken(ctx context.Context, proxyToken string) (string, error) {
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionProxy)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID)
if !ok {
return "", errors.New("extension is not configured")
}
return extensionID, nil
}
// VerifyMobileProxyToken verifies mobile token and returns extension_id
func (p *Pairing) VerifyMobileProxyToken(ctx context.Context, proxyToken string) (string, error) {
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileProxy)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID)
if !ok {
return "", errors.New("extension is not configured")
@ -29,8 +49,10 @@ func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (stri
// VerifyConnectionToken verifies connection token and returns extension_id
func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) {
// TODO verify proxy token and take extension from token, this is for debug only.
extensionID := connectionToken
extensionID, err := p.signSvc.CanI(connectionToken, sign.ConnectionTypeMobileConfirm)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID)
if !ok {
return "", errors.New("extension is not configured")

View File

@ -61,7 +61,7 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu
gCtx.Status(http.StatusForbidden)
return
}
extensionID, err := pairingApp.VerifyProxyToken(gCtx, token)
extensionID, err := pairingApp.VerifyExtProxyToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify proxy token: %v", err)
gCtx.Status(http.StatusInternalServerError)
@ -106,11 +106,13 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
return
}
if err := pairingApp.ConfirmPairing(gCtx, req, extensionID); err != nil {
resp, err := pairingApp.ConfirmPairing(gCtx, req, extensionID)
if err != nil {
logging.Errorf("Failed to ConfirmPairing: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
gCtx.JSON(http.StatusOK, resp)
}
}
@ -122,7 +124,7 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc
gCtx.Status(http.StatusForbidden)
return
}
extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token)
extensionID, err := pairingApp.VerifyMobileProxyToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify connection token: %v", err)
gCtx.Status(http.StatusInternalServerError)

View File

@ -9,10 +9,12 @@ import (
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/sign"
)
type Pairing struct {
store store
store store
signSvc *sign.Service
}
type store interface {
@ -22,12 +24,17 @@ type store interface {
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
}
func NewPairingApp() *Pairing {
func NewPairingApp(signService *sign.Service) *Pairing {
return &Pairing{
store: NewMemoryStore(),
store: NewMemoryStore(),
signSvc: signService,
}
}
const (
pairingTokenValidityDuration = 3 * time.Minute
)
type ConfigureBrowserExtensionRequest struct {
ExtensionID string `json:"extension_id"`
}
@ -39,12 +46,26 @@ type ConfigureBrowserExtensionResponse struct {
func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) {
p.store.AddExtension(ctx, req.ExtensionID)
// TODO: generate connection token and pairing token.
connectionToken := req.ExtensionID
pairingToken := req.ExtensionID
pairingToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: req.ExtensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeBrowserExtensionWait,
})
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to generate pairing token: %v", err)
}
mobileToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: req.ExtensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileConfirm,
})
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("Failed to generate mobile confirm token: %v", err)
}
return ConfigureBrowserExtensionResponse{
ConnectionToken: connectionToken,
ConnectionToken: mobileToken,
BrowserExtensionPairingToken: pairingToken,
}, nil
}
@ -119,10 +140,17 @@ func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logr
}
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
// generate token here
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: extID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
})
if err != nil {
return fmt.Errorf("Failed to generate ext proxy token: %v", err)
}
if err := conn.WriteJSON(WaitForConnectionResponse{
// TODO: replace with real token.
BrowserExtensionProxyToken: extID,
BrowserExtensionProxyToken: extProxyToken,
Status: "ok",
DeviceID: deviceID,
}); err != nil {
@ -141,12 +169,27 @@ type ConfirmPairingRequest struct {
DeviceID string `json:"device_id"`
}
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) error {
return p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
type ConfirmPairingResponse struct {
ProxyToken string `json:"proxy_token"`
}
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) (ConfirmPairingResponse, error) {
mobileProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: extensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileProxy,
})
if err != nil {
return ConfirmPairingResponse{}, fmt.Errorf("Failed to generate ext proxy token: %v", err)
}
if err := p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
Device: MobileDevice{
DeviceID: req.DeviceID,
FCMToken: req.FCMToken,
},
PairedAt: time.Now().UTC(),
})
}); err != nil {
return ConfirmPairingResponse{}, err
}
return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil
}

View File

@ -1,7 +1,15 @@
package pass
import (
"log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/pass/sign"
"github.com/twofas/2fas-server/config"
httphelpers "github.com/twofas/2fas-server/internal/common/http"
"github.com/twofas/2fas-server/internal/common/recovery"
"github.com/twofas/2fas-server/internal/pass/pairing"
@ -12,8 +20,27 @@ type Server struct {
addr string
}
func NewServer(addr string) *Server {
pairingApp := pairing.NewPairingApp()
func NewServer(cfg config.PassConfig) *Server {
var awsEndpoint *string
if cfg.AWSEndpoint != "" {
awsEndpoint = aws.String(cfg.AWSEndpoint)
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
S3ForcePathStyle: aws.Bool(true),
Endpoint: awsEndpoint,
})
if err != nil {
log.Fatal(err)
}
kmsClient := kms.New(sess)
signSvc, err := sign.NewService(cfg.KMSKeyID, kmsClient)
if err != nil {
log.Fatal(err)
}
pairingApp := pairing.NewPairingApp(signSvc)
proxyApp := pairing.NewProxy()
router := gin.New()
@ -35,7 +62,7 @@ func NewServer(addr string) *Server {
return &Server{
router: router,
addr: addr,
addr: cfg.Addr,
}
}

View File

@ -59,4 +59,5 @@ const (
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
)

View File

@ -46,7 +46,8 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
t.Fatal(err)
}
if err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy); err != nil {
_, err = srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
if err != nil {
t.Fatal(err)
}
}
@ -156,7 +157,7 @@ func TestSignAndVerify(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
token := tc.tokenFn()
err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
_, err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
if err == nil {
t.Fatalf("Expected error %v, got nil", tc.expectedError)
}

View File

@ -3,15 +3,22 @@ package sign
import (
"errors"
"fmt"
"slices"
"github.com/golang-jwt/jwt/v5"
)
var ErrInvalidClaims = errors.New("invalid claims")
type customClaims struct {
ConnectionID string `json:"c_id"`
jwt.RegisteredClaims
}
// CanI establish connection with type tp given claims in token.
func (s Service) CanI(tokenString string, ct ConnectionType) error {
cl := jwt.MapClaims{}
// Returns extension_id from claims if token is valid for given type.
func (s Service) CanI(tokenString string, ct ConnectionType) (string, error) {
cl := customClaims{}
// In Sign we removed `jwtHeader` from JWT before returning it.
// We need to add it again before doing the verification.
@ -27,19 +34,19 @@ func (s Service) CanI(tokenString string, ct ConnectionType) error {
jwt.WithExpirationRequired(),
)
if err != nil {
return fmt.Errorf("failed to parse token: %w", err)
return "", fmt.Errorf("failed to parse token: %w", err)
}
claims, err := token.Claims.GetAudience()
audClaims, err := token.Claims.GetAudience()
if err != nil {
return fmt.Errorf("failed to get claims: %w", err)
return "", fmt.Errorf("failed to get claims: %w", err)
}
for _, aud := range claims {
if aud == string(ct) {
return nil
}
if !slices.Contains(audClaims, string(ct)) {
return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
}
return fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
if cl.ConnectionID == "" {
return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, "c_id")
}
// TODO: rename connectionID to extensionID.
return cl.ConnectionID, nil
}

View File

@ -43,9 +43,11 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
t.Log(token)
t.Log("Length of the token is", len(token))
if err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy); err != nil {
extensionID, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
if err != nil {
t.Fatal(err)
}
t.Log(extensionID)
}
func TestSignAndVerify(t *testing.T) {
@ -139,7 +141,7 @@ func TestSignAndVerify(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
token := tc.tokenFn()
err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
_, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
if err == nil {
t.Fatalf("Expected error %v, got nil", tc.expectedError)
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
@ -24,7 +25,13 @@ var (
wsDialer = websocket.DefaultDialer
)
const api = "localhost:8082"
func getAPIURL() string {
addr := os.Getenv("PASS_ADDR")
if addr != "" {
return addr
}
return "localhost:8082"
}
func TestPassHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension()
@ -38,15 +45,15 @@ func TestPassHappyFlow(t *testing.T) {
go func() {
defer close(browserExtensionDone)
err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
if err != nil {
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
return
}
err = proxyWebSocket(
"ws://"+api+"/browser_extension/proxy_to_mobile",
resp.BrowserExtensionPairingToken,
"ws://"+getAPIURL()+"/browser_extension/proxy_to_mobile",
extProxyToken,
"sent from browser extension",
"sent from mobile")
if err != nil {
@ -58,15 +65,15 @@ func TestPassHappyFlow(t *testing.T) {
go func() {
defer close(mobileDone)
err := confirmMobile(resp.ConnectionToken)
mobileProxyToken, err := confirmMobile(resp.ConnectionToken)
if err != nil {
t.Errorf("Mobile: confirm failed: %v", err)
return
}
err = proxyWebSocket(
"ws://"+api+"/mobile/proxy_to_browser_extension",
resp.BrowserExtensionPairingToken,
"ws://"+getAPIURL()+"/mobile/proxy_to_browser_extension",
mobileProxyToken,
"sent from mobile",
"sent from browser extension",
)
@ -75,41 +82,42 @@ func TestPassHappyFlow(t *testing.T) {
return
}
}()
<-browserExtensionDone
<-mobileDone
}
func browserExtensionWaitForConfirm(token string) error {
url := "ws://" + api + "/browser_extension/wait_for_connection"
func browserExtensionWaitForConfirm(token string) (string, error) {
url := "ws://" + getAPIURL() + "/browser_extension/wait_for_connection"
var resp struct {
Status string `json:"status"`
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
Status string `json:"status"`
DeviceID string `json:"device_id"`
}
conn, err := dialWS(url, token)
if err != nil {
return err
return "", err
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(time.Second))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
return fmt.Errorf("error reading from connection: %w", err)
return "", fmt.Errorf("error reading from connection: %w", err)
}
if err := json.Unmarshal(message, &resp); err != nil {
return fmt.Errorf("failed to decode message: %w", err)
return "", fmt.Errorf("failed to decode message: %w", err)
}
const expectedStatus = "ok"
if resp.Status != expectedStatus {
return fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
}
return nil
return resp.BrowserExtensionProxyToken, nil
}
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
url := "http://" + api + "/browser_extension/configure"
url := "http://" + getAPIURL() + "/browser_extension/configure"
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String()))
if err != nil {
@ -138,26 +146,39 @@ func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
return resp, nil
}
func confirmMobile(connectionToken string) error {
url := "http://" + api + "/mobile/confirm"
// confirmMobile confirms pairing and returns mobile proxy token.
func confirmMobile(connectionToken string) (string, error) {
url := "http://" + getAPIURL() + "/mobile/confirm"
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String()))
if err != nil {
return fmt.Errorf("failed to prepare the reqest: %w", err)
return "", fmt.Errorf("failed to prepare the reqest: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
httpResp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform the reqest: %w", err)
return "", fmt.Errorf("failed to perform the reqest: %w", err)
}
defer httpResp.Body.Close()
if httpResp.StatusCode > 299 {
return fmt.Errorf("unexpected response: %s", httpResp.Status)
bb, err := io.ReadAll(httpResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read body from response: %w", err)
}
return nil
if httpResp.StatusCode >= 300 {
return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
}
var resp struct {
ProxyToken string `json:"proxy_token"`
}
if err := json.Unmarshal(bb, &resp); err != nil {
return "", fmt.Errorf("failed to decode the response: %w", err)
}
return resp.ProxyToken, nil
}
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
@ -165,7 +186,7 @@ func confirmMobile(connectionToken string) error {
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
conn, err := dialWS(url, token)
if err != nil {
return nil
return err
}
defer conn.Close()
@ -204,7 +225,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) {
},
})
if err != nil {
return nil, fmt.Errorf("failed to dial ws %q: %v", url, err)
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
}
return conn, nil
}