feat: connect pass with kms (#29)

This commit is contained in:
Tobiasz Heller 2024-01-24 20:57:31 +01:00 committed by GitHub
parent c00c8a4d5b
commit 782e77173d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 217 additions and 72 deletions

4
.env
View File

@ -15,3 +15,7 @@ SECURITY_RATE_LIMIT_BE=100
SECURITY_RATE_LIMIT_MOBILE=100 SECURITY_RATE_LIMIT_MOBILE=100
PASS_ADDR=:8082 PASS_ADDR=:8082
AWS_ACCESS_KEY_ID=test
AWS_SECRET_ACCESS_KEY=test
AWS_ENDPOINT="http://localhost:4566"

View File

@ -32,6 +32,7 @@ tests-e2e: ## run end to end tests
go test ./tests/mobile/... -count=1 go test ./tests/mobile/... -count=1
go test ./tests/support/... -count=1 go test ./tests/support/... -count=1
go test ./tests/system/... -count=1 go test ./tests/system/... -count=1
go test ./tests/pass/... -count=1
vendor-licenses: ## report vendor licenses vendor-licenses: ## report vendor licenses

View File

@ -17,7 +17,7 @@ func main() {
logging.Fatal(err.Error()) logging.Fatal(err.Error())
} }
server := pass.NewServer(cfg.Addr) server := pass.NewServer(cfg)
if err := server.Run(); err != nil { if err := server.Run(); err != nil {
logging.Fatal(err.Error()) logging.Fatal(err.Error())

View File

@ -1,5 +1,7 @@
package config package config
type PassConfig struct { type PassConfig struct {
Addr string `envconfig:"PASS_ADDR" default:":8082"` Addr string `envconfig:"PASS_ADDR" default:":8082"`
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
} }

View File

@ -90,8 +90,14 @@ services:
- '1000' - '1000'
ports: ports:
- "8084:8082" - "8084:8082"
environment:
# overwrite AWS_ENDPOINT from .env file. One in env is used to running app from local also.
AWS_ENDPOINT: http://localstack-main:4566
env_file: env_file:
- .env - .env
depends_on:
localstack:
condition: service_healthy
localstack: localstack:
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}" container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
@ -100,6 +106,12 @@ services:
- "127.0.0.1:4566:4566" - "127.0.0.1:4566:4566"
environment: environment:
- DEBUG=1 - DEBUG=1
healthcheck:
test: >-
curl -s localhost:4566/_localstack/health | grep -q '"kms": "running"'
interval: 5s
timeout: 5s
retries: 5
volumes: volumes:
- "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook - "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook
- "./data/localstack:/var/lib/localstack" - "./data/localstack:/var/lib/localstack"

View File

@ -3,12 +3,17 @@ package pairing
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"github.com/twofas/2fas-server/internal/pass/sign"
) )
// VerifyPairingToken verifies pairing token and returns extension_id // VerifyPairingToken verifies pairing token and returns extension_id
func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) { func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) {
// TODO verify pairing token and take extension from token, this is for debug only. extensionID, err := p.signSvc.CanI(pairingToken, sign.ConnectionTypeBrowserExtensionWait)
extensionID := pairingToken if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID) ok := p.store.ExtensionExists(ctx, extensionID)
if !ok { if !ok {
return "", errors.New("extension is not configured") return "", errors.New("extension is not configured")
@ -16,10 +21,25 @@ func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (
return extensionID, nil return extensionID, nil
} }
// VerifyProxyToken verifies proxy token and returns extension_id // VerifyExtProxyToken verifies proxy token and returns extension_id
func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (string, error) { func (p *Pairing) VerifyExtProxyToken(ctx context.Context, proxyToken string) (string, error) {
// TODO verify proxy token and take extension from token, this is for debug only. extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionProxy)
extensionID := proxyToken if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID)
if !ok {
return "", errors.New("extension is not configured")
}
return extensionID, nil
}
// VerifyMobileProxyToken verifies mobile token and returns extension_id
func (p *Pairing) VerifyMobileProxyToken(ctx context.Context, proxyToken string) (string, error) {
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileProxy)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID) ok := p.store.ExtensionExists(ctx, extensionID)
if !ok { if !ok {
return "", errors.New("extension is not configured") return "", errors.New("extension is not configured")
@ -29,8 +49,10 @@ func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (stri
// VerifyConnectionToken verifies connection token and returns extension_id // VerifyConnectionToken verifies connection token and returns extension_id
func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) { func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) {
// TODO verify proxy token and take extension from token, this is for debug only. extensionID, err := p.signSvc.CanI(connectionToken, sign.ConnectionTypeMobileConfirm)
extensionID := connectionToken if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
ok := p.store.ExtensionExists(ctx, extensionID) ok := p.store.ExtensionExists(ctx, extensionID)
if !ok { if !ok {
return "", errors.New("extension is not configured") return "", errors.New("extension is not configured")

View File

@ -61,7 +61,7 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu
gCtx.Status(http.StatusForbidden) gCtx.Status(http.StatusForbidden)
return return
} }
extensionID, err := pairingApp.VerifyProxyToken(gCtx, token) extensionID, err := pairingApp.VerifyExtProxyToken(gCtx, token)
if err != nil { if err != nil {
logging.Errorf("Failed to verify proxy token: %v", err) logging.Errorf("Failed to verify proxy token: %v", err)
gCtx.Status(http.StatusInternalServerError) gCtx.Status(http.StatusInternalServerError)
@ -106,11 +106,13 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
return return
} }
if err := pairingApp.ConfirmPairing(gCtx, req, extensionID); err != nil { resp, err := pairingApp.ConfirmPairing(gCtx, req, extensionID)
if err != nil {
logging.Errorf("Failed to ConfirmPairing: %v", err) logging.Errorf("Failed to ConfirmPairing: %v", err)
gCtx.Status(http.StatusInternalServerError) gCtx.Status(http.StatusInternalServerError)
return return
} }
gCtx.JSON(http.StatusOK, resp)
} }
} }
@ -122,7 +124,7 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc
gCtx.Status(http.StatusForbidden) gCtx.Status(http.StatusForbidden)
return return
} }
extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token) extensionID, err := pairingApp.VerifyMobileProxyToken(gCtx, token)
if err != nil { if err != nil {
logging.Errorf("Failed to verify connection token: %v", err) logging.Errorf("Failed to verify connection token: %v", err)
gCtx.Status(http.StatusInternalServerError) gCtx.Status(http.StatusInternalServerError)

View File

@ -9,10 +9,12 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/sign"
) )
type Pairing struct { type Pairing struct {
store store store store
signSvc *sign.Service
} }
type store interface { type store interface {
@ -22,12 +24,17 @@ type store interface {
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
} }
func NewPairingApp() *Pairing { func NewPairingApp(signService *sign.Service) *Pairing {
return &Pairing{ return &Pairing{
store: NewMemoryStore(), store: NewMemoryStore(),
signSvc: signService,
} }
} }
const (
pairingTokenValidityDuration = 3 * time.Minute
)
type ConfigureBrowserExtensionRequest struct { type ConfigureBrowserExtensionRequest struct {
ExtensionID string `json:"extension_id"` ExtensionID string `json:"extension_id"`
} }
@ -39,12 +46,26 @@ type ConfigureBrowserExtensionResponse struct {
func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) { func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) {
p.store.AddExtension(ctx, req.ExtensionID) p.store.AddExtension(ctx, req.ExtensionID)
// TODO: generate connection token and pairing token.
connectionToken := req.ExtensionID
pairingToken := req.ExtensionID
pairingToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: req.ExtensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeBrowserExtensionWait,
})
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to generate pairing token: %v", err)
}
mobileToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: req.ExtensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileConfirm,
})
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("Failed to generate mobile confirm token: %v", err)
}
return ConfigureBrowserExtensionResponse{ return ConfigureBrowserExtensionResponse{
ConnectionToken: connectionToken, ConnectionToken: mobileToken,
BrowserExtensionPairingToken: pairingToken, BrowserExtensionPairingToken: pairingToken,
}, nil }, nil
} }
@ -119,10 +140,17 @@ func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logr
} }
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error { func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
// generate token here extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: extID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
})
if err != nil {
return fmt.Errorf("Failed to generate ext proxy token: %v", err)
}
if err := conn.WriteJSON(WaitForConnectionResponse{ if err := conn.WriteJSON(WaitForConnectionResponse{
// TODO: replace with real token. BrowserExtensionProxyToken: extProxyToken,
BrowserExtensionProxyToken: extID,
Status: "ok", Status: "ok",
DeviceID: deviceID, DeviceID: deviceID,
}); err != nil { }); err != nil {
@ -141,12 +169,27 @@ type ConfirmPairingRequest struct {
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
} }
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) error { type ConfirmPairingResponse struct {
return p.store.SetPairingInfo(ctx, extensionID, PairingInfo{ ProxyToken string `json:"proxy_token"`
}
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) (ConfirmPairingResponse, error) {
mobileProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: extensionID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileProxy,
})
if err != nil {
return ConfirmPairingResponse{}, fmt.Errorf("Failed to generate ext proxy token: %v", err)
}
if err := p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
Device: MobileDevice{ Device: MobileDevice{
DeviceID: req.DeviceID, DeviceID: req.DeviceID,
FCMToken: req.FCMToken, FCMToken: req.FCMToken,
}, },
PairedAt: time.Now().UTC(), PairedAt: time.Now().UTC(),
}) }); err != nil {
return ConfirmPairingResponse{}, err
}
return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil
} }

View File

@ -1,7 +1,15 @@
package pass package pass
import ( import (
"log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/pass/sign"
"github.com/twofas/2fas-server/config"
httphelpers "github.com/twofas/2fas-server/internal/common/http" httphelpers "github.com/twofas/2fas-server/internal/common/http"
"github.com/twofas/2fas-server/internal/common/recovery" "github.com/twofas/2fas-server/internal/common/recovery"
"github.com/twofas/2fas-server/internal/pass/pairing" "github.com/twofas/2fas-server/internal/pass/pairing"
@ -12,8 +20,27 @@ type Server struct {
addr string addr string
} }
func NewServer(addr string) *Server { func NewServer(cfg config.PassConfig) *Server {
pairingApp := pairing.NewPairingApp() var awsEndpoint *string
if cfg.AWSEndpoint != "" {
awsEndpoint = aws.String(cfg.AWSEndpoint)
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("us-east-1"),
S3ForcePathStyle: aws.Bool(true),
Endpoint: awsEndpoint,
})
if err != nil {
log.Fatal(err)
}
kmsClient := kms.New(sess)
signSvc, err := sign.NewService(cfg.KMSKeyID, kmsClient)
if err != nil {
log.Fatal(err)
}
pairingApp := pairing.NewPairingApp(signSvc)
proxyApp := pairing.NewProxy() proxyApp := pairing.NewProxy()
router := gin.New() router := gin.New()
@ -35,7 +62,7 @@ func NewServer(addr string) *Server {
return &Server{ return &Server{
router: router, router: router,
addr: addr, addr: cfg.Addr,
} }
} }

View File

@ -59,4 +59,5 @@ const (
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait" ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy" ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy" ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
) )

View File

@ -46,7 +46,8 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy); err != nil { _, err = srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -156,7 +157,7 @@ func TestSignAndVerify(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
token := tc.tokenFn() token := tc.tokenFn()
err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy) _, err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
if err == nil { if err == nil {
t.Fatalf("Expected error %v, got nil", tc.expectedError) t.Fatalf("Expected error %v, got nil", tc.expectedError)
} }

View File

@ -3,15 +3,22 @@ package sign
import ( import (
"errors" "errors"
"fmt" "fmt"
"slices"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
var ErrInvalidClaims = errors.New("invalid claims") var ErrInvalidClaims = errors.New("invalid claims")
type customClaims struct {
ConnectionID string `json:"c_id"`
jwt.RegisteredClaims
}
// CanI establish connection with type tp given claims in token. // CanI establish connection with type tp given claims in token.
func (s Service) CanI(tokenString string, ct ConnectionType) error { // Returns extension_id from claims if token is valid for given type.
cl := jwt.MapClaims{} func (s Service) CanI(tokenString string, ct ConnectionType) (string, error) {
cl := customClaims{}
// In Sign we removed `jwtHeader` from JWT before returning it. // In Sign we removed `jwtHeader` from JWT before returning it.
// We need to add it again before doing the verification. // We need to add it again before doing the verification.
@ -27,19 +34,19 @@ func (s Service) CanI(tokenString string, ct ConnectionType) error {
jwt.WithExpirationRequired(), jwt.WithExpirationRequired(),
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to parse token: %w", err) return "", fmt.Errorf("failed to parse token: %w", err)
} }
claims, err := token.Claims.GetAudience() audClaims, err := token.Claims.GetAudience()
if err != nil { if err != nil {
return fmt.Errorf("failed to get claims: %w", err) return "", fmt.Errorf("failed to get claims: %w", err)
} }
if !slices.Contains(audClaims, string(ct)) {
for _, aud := range claims { return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
if aud == string(ct) {
return nil
}
} }
if cl.ConnectionID == "" {
return fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct) return "", fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, "c_id")
}
// TODO: rename connectionID to extensionID.
return cl.ConnectionID, nil
} }

View File

@ -43,9 +43,11 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
t.Log(token) t.Log(token)
t.Log("Length of the token is", len(token)) t.Log("Length of the token is", len(token))
if err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy); err != nil { extensionID, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(extensionID)
} }
func TestSignAndVerify(t *testing.T) { func TestSignAndVerify(t *testing.T) {
@ -139,7 +141,7 @@ func TestSignAndVerify(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
token := tc.tokenFn() token := tc.tokenFn()
err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy) _, err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
if err == nil { if err == nil {
t.Fatalf("Expected error %v, got nil", tc.expectedError) t.Fatalf("Expected error %v, got nil", tc.expectedError)
} }

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"testing" "testing"
"time" "time"
@ -24,7 +25,13 @@ var (
wsDialer = websocket.DefaultDialer wsDialer = websocket.DefaultDialer
) )
const api = "localhost:8082" func getAPIURL() string {
addr := os.Getenv("PASS_ADDR")
if addr != "" {
return addr
}
return "localhost:8082"
}
func TestPassHappyFlow(t *testing.T) { func TestPassHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension() resp, err := configureBrowserExtension()
@ -38,15 +45,15 @@ func TestPassHappyFlow(t *testing.T) {
go func() { go func() {
defer close(browserExtensionDone) defer close(browserExtensionDone)
err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken) extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
if err != nil { if err != nil {
t.Errorf("Error when Browser Extension waited for confirm: %v", err) t.Errorf("Error when Browser Extension waited for confirm: %v", err)
return return
} }
err = proxyWebSocket( err = proxyWebSocket(
"ws://"+api+"/browser_extension/proxy_to_mobile", "ws://"+getAPIURL()+"/browser_extension/proxy_to_mobile",
resp.BrowserExtensionPairingToken, extProxyToken,
"sent from browser extension", "sent from browser extension",
"sent from mobile") "sent from mobile")
if err != nil { if err != nil {
@ -58,15 +65,15 @@ func TestPassHappyFlow(t *testing.T) {
go func() { go func() {
defer close(mobileDone) defer close(mobileDone)
err := confirmMobile(resp.ConnectionToken) mobileProxyToken, err := confirmMobile(resp.ConnectionToken)
if err != nil { if err != nil {
t.Errorf("Mobile: confirm failed: %v", err) t.Errorf("Mobile: confirm failed: %v", err)
return return
} }
err = proxyWebSocket( err = proxyWebSocket(
"ws://"+api+"/mobile/proxy_to_browser_extension", "ws://"+getAPIURL()+"/mobile/proxy_to_browser_extension",
resp.BrowserExtensionPairingToken, mobileProxyToken,
"sent from mobile", "sent from mobile",
"sent from browser extension", "sent from browser extension",
) )
@ -75,41 +82,42 @@ func TestPassHappyFlow(t *testing.T) {
return return
} }
}() }()
<-browserExtensionDone <-browserExtensionDone
<-mobileDone <-mobileDone
} }
func browserExtensionWaitForConfirm(token string) error { func browserExtensionWaitForConfirm(token string) (string, error) {
url := "ws://" + api + "/browser_extension/wait_for_connection" url := "ws://" + getAPIURL() + "/browser_extension/wait_for_connection"
var resp struct { var resp struct {
Status string `json:"status"` BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
Status string `json:"status"`
DeviceID string `json:"device_id"`
} }
conn, err := dialWS(url, token) conn, err := dialWS(url, token)
if err != nil { if err != nil {
return err return "", err
} }
defer conn.Close() defer conn.Close()
conn.SetReadDeadline(time.Now().Add(time.Second)) conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, message, err := conn.ReadMessage() _, message, err := conn.ReadMessage()
if err != nil { if err != nil {
return fmt.Errorf("error reading from connection: %w", err) return "", fmt.Errorf("error reading from connection: %w", err)
} }
if err := json.Unmarshal(message, &resp); err != nil { if err := json.Unmarshal(message, &resp); err != nil {
return fmt.Errorf("failed to decode message: %w", err) return "", fmt.Errorf("failed to decode message: %w", err)
} }
const expectedStatus = "ok" const expectedStatus = "ok"
if resp.Status != expectedStatus { if resp.Status != expectedStatus {
return fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus) return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
} }
return nil return resp.BrowserExtensionProxyToken, nil
} }
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) { func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
url := "http://" + api + "/browser_extension/configure" url := "http://" + getAPIURL() + "/browser_extension/configure"
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String())) req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String()))
if err != nil { if err != nil {
@ -138,26 +146,39 @@ func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
return resp, nil return resp, nil
} }
func confirmMobile(connectionToken string) error { // confirmMobile confirms pairing and returns mobile proxy token.
url := "http://" + api + "/mobile/confirm" func confirmMobile(connectionToken string) (string, error) {
url := "http://" + getAPIURL() + "/mobile/confirm"
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String())) req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String()))
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare the reqest: %w", err) return "", fmt.Errorf("failed to prepare the reqest: %w", err)
} }
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
httpResp, err := httpClient.Do(req) httpResp, err := httpClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to perform the reqest: %w", err) return "", fmt.Errorf("failed to perform the reqest: %w", err)
} }
defer httpResp.Body.Close() defer httpResp.Body.Close()
if httpResp.StatusCode > 299 { bb, err := io.ReadAll(httpResp.Body)
return fmt.Errorf("unexpected response: %s", httpResp.Status) if err != nil {
return "", fmt.Errorf("failed to read body from response: %w", err)
} }
return nil if httpResp.StatusCode >= 300 {
return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
}
var resp struct {
ProxyToken string `json:"proxy_token"`
}
if err := json.Unmarshal(bb, &resp); err != nil {
return "", fmt.Errorf("failed to decode the response: %w", err)
}
return resp.ProxyToken, nil
} }
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and // proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
@ -165,7 +186,7 @@ func confirmMobile(connectionToken string) error {
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error { func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
conn, err := dialWS(url, token) conn, err := dialWS(url, token)
if err != nil { if err != nil {
return nil return err
} }
defer conn.Close() defer conn.Close()
@ -204,7 +225,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) {
}, },
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial ws %q: %v", url, err) return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
} }
return conn, nil return conn, nil
} }