diff --git a/internal/pass/pairing/handlers.go b/internal/pass/pairing/handlers.go index 4a08f53..8322f8f 100644 --- a/internal/pass/pairing/handlers.go +++ b/internal/pass/pairing/handlers.go @@ -11,7 +11,7 @@ import ( "github.com/twofas/2fas-server/internal/common/logging" ) -func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc { +func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc { return func(gCtx *gin.Context) { var req ConfigureBrowserExtensionRequest if err := gCtx.BindJSON(&req); err != nil { @@ -33,10 +33,9 @@ func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc { } } -func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc { +func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc { return func(gCtx *gin.Context) { - // TODO: consider moving auth to middleware. - token, err := tokenFromRequest(gCtx) + token, err := tokenFromWSProtocol(gCtx.Request) if err != nil { logging.Errorf("Failed to get token from request: %v", err) gCtx.Status(http.StatusForbidden) @@ -49,14 +48,14 @@ func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc { gCtx.Status(http.StatusInternalServerError) return } + pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID) } } -func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { +func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { return func(gCtx *gin.Context) { - // TODO: consider moving auth to middleware. - token, err := tokenFromRequest(gCtx) + token, err := tokenFromWSProtocol(gCtx.Request) if err != nil { logging.Errorf("Failed to get token from request: %v", err) gCtx.Status(http.StatusForbidden) @@ -84,7 +83,6 @@ func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.Hand func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc { return func(gCtx *gin.Context) { - // TODO: consider moving auth to middleware. token, err := tokenFromRequest(gCtx) if err != nil { logging.Errorf("Failed to get token from request: %v", err) @@ -116,10 +114,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc { } } -func MobileProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { +func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { return func(gCtx *gin.Context) { - // TODO: consider moving auth to middleware. - token, err := tokenFromRequest(gCtx) + token, err := tokenFromWSProtocol(gCtx.Request) if err != nil { logging.Errorf("Failed to get token from request: %v", err) gCtx.Status(http.StatusForbidden) diff --git a/internal/pass/pairing/pairing.go b/internal/pass/pairing/pairing.go index 1a5e6f3..3006f13 100644 --- a/internal/pass/pairing/pairing.go +++ b/internal/pass/pairing/pairing.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "os" "time" "github.com/gorilla/websocket" @@ -63,6 +62,11 @@ type WaitForConnectionResponse struct { func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) { log := logging.WithField("extension_id", extID) + upgrader, err := wsUpgraderForProtocol(r) + if err != nil { + log.Error(err) + return + } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("Failed to upgrade on ServePairingWS: %v", err) @@ -146,17 +150,3 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, PairedAt: time.Now().UTC(), }) } - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 4 * 1024, - WriteBufferSize: 4 * 1024, - CheckOrigin: func(r *http.Request) bool { - allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN") - - if allowedOrigin != "" { - return r.Header.Get("Origin") == allowedOrigin - } - - return true - }, -} diff --git a/internal/pass/pairing/proxy.go b/internal/pass/pairing/proxy.go index 3d82199..5c308c4 100644 --- a/internal/pass/pairing/proxy.go +++ b/internal/pass/pairing/proxy.go @@ -197,6 +197,11 @@ func (c *client) writePump() { func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) { log := logging.WithField("extension_id", extID).WithField("device_id", deviceID) + upgrader, err := wsUpgraderForProtocol(r) + if err != nil { + log.Error(err) + return + } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err) @@ -228,6 +233,11 @@ func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Req } func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) { + upgrader, err := wsUpgraderForProtocol(r) + if err != nil { + logging.Error(err) + return + } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err) diff --git a/internal/pass/pairing/websocket.go b/internal/pass/pairing/websocket.go new file mode 100644 index 0000000..5a5b3d3 --- /dev/null +++ b/internal/pass/pairing/websocket.go @@ -0,0 +1,100 @@ +package pairing + +import ( + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/textproto" + "os" + "slices" + "strings" + "unicode/utf8" + + "github.com/gorilla/websocket" +) + +const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io." + +var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol") + +// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header. +// It is used because websocket API in browser does not allow to pass headers. +// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0 +// +// Client provides a bearer token as a subprotocol in the format +// "base64url.bearer.authorization.2pass.io.". +// This function also modified request header by removing authorization header. +// Server according to spec must return at least 1 protocol to client, so this fn checks +// if at least 1 protocol is sent, beside authorization header. +func tokenFromWSProtocol(req *http.Request) (string, error) { + token := "" + sawTokenProtocol := false + filteredProtocols := []string{} + for _, protocolHeader := range req.Header[protocolHeader] { + for _, protocol := range strings.Split(protocolHeader, ",") { + protocol = strings.TrimSpace(protocol) + + if !strings.HasPrefix(protocol, bearerProtocolPrefix) { + filteredProtocols = append(filteredProtocols, protocol) + continue + } + + if sawTokenProtocol { + return "", errors.New("multiple base64.bearer.authorization tokens specified") + } + sawTokenProtocol = true + + encodedToken := strings.TrimPrefix(protocol, bearerProtocolPrefix) + decodedToken, err := base64.RawURLEncoding.DecodeString(encodedToken) + if err != nil { + return "", errors.New("invalid base64.bearer.authorization token encoding") + } + if !utf8.Valid(decodedToken) { + return "", errors.New("invalid base64.bearer.authorization token") + } + token = string(decodedToken) + } + } + + if len(token) == 0 { + return "", errors.New("empty token") + } + + // Must pass at least one other subprotocol so that we can remove the one containing the bearer token, + // and there is at least one to echo back to the client + if len(filteredProtocols) == 0 { + return "", errors.New("missing additional subprotocol") + } + + // https://tools.ietf.org/html/rfc6455#section-11.3.4 indicates the Sec-WebSocket-Protocol header may appear multiple times + // in a request, and is logically the same as a single Sec-WebSocket-Protocol header field that contains all values + req.Header.Set(protocolHeader, strings.Join(filteredProtocols, ",")) + return token, nil +} + +// supportedProtocol2pass is Protocol that will be sent back on upgrade mechanism. +const supportedProtocol2pass = "2pass.io" + +var upgrader2pass = websocket.Upgrader{ + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + CheckOrigin: func(r *http.Request) bool { + allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN") + + if allowedOrigin != "" { + return r.Header.Get("Origin") == allowedOrigin + } + + return true + }, + Subprotocols: []string{supportedProtocol2pass}, +} + +func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) { + protocols := strings.Split(req.Header.Get(protocolHeader), ",") + if slices.Contains(protocols, supportedProtocol2pass) { + return upgrader2pass, nil + } + return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols) +} diff --git a/internal/pass/pairing/websocket_test.go b/internal/pass/pairing/websocket_test.go new file mode 100644 index 0000000..4c3351c --- /dev/null +++ b/internal/pass/pairing/websocket_test.go @@ -0,0 +1,63 @@ +package pairing + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_tokenFromWSProtocol(t *testing.T) { + tests := []struct { + name string + protocolHeader string + assertFn func(t *testing.T, token string, err error) + }{ + { + name: "valid token with additional subprotocol", + protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,ws", + assertFn: func(t *testing.T, token string, err error) { + require.NoError(t, err) + require.Equal(t, "test1", token) + }, + }, + { + name: "missing header", + assertFn: func(t *testing.T, token string, err error) { + require.ErrorContains(t, err, "empty token") + }, + }, + { + name: "invalid encoding", + protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdA==", + assertFn: func(t *testing.T, token string, err error) { + require.ErrorContains(t, err, "invalid base64.bearer.authorization token encoding") + }, + }, + { + name: "missing other protocol", + protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE", + assertFn: func(t *testing.T, token string, err error) { + require.ErrorContains(t, err, "missing additional subprotocol") + }, + }, + { + name: "double authorization header", + protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,base64url.bearer.authorization.2pass.io.dGVzdDE", + assertFn: func(t *testing.T, token string, err error) { + require.ErrorContains(t, err, "multiple base64.bearer.authorization tokens specified") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "http://localhost", nil) + require.NoError(t, err) + if protocolHeader != "" { + req.Header.Set(protocolHeader, tt.protocolHeader) + } + got, err := tokenFromWSProtocol(req) + tt.assertFn(t, got, err) + }) + } +} diff --git a/internal/pass/server.go b/internal/pass/server.go index 7ee50e9..9018eb1 100644 --- a/internal/pass/server.go +++ b/internal/pass/server.go @@ -27,11 +27,11 @@ func NewServer(addr string) *Server { context.Status(200) }) - router.POST("/browser_extension/configure", pairing.BrowserExtensionConfigureHandler(pairingApp)) - router.GET("/browser_extension/wait_for_connection", pairing.BrowserExtensionWaitForConnHandler(pairingApp)) - router.GET("/browser_extension/proxy_to_mobile", pairing.BrowserExtensionProxyHandler(pairingApp, proxyApp)) + router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp)) + router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp)) + router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp)) router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp)) - router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyHandler(pairingApp, proxyApp)) + router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp)) return &Server{ router: router,