feat: use Sec-WebSocket-Protocol to pass auth on ws

This commit is contained in:
Tobiasz Heller 2024-01-20 21:21:46 +01:00
parent 17fb204680
commit 85f0cd1206
6 changed files with 190 additions and 30 deletions

View File

@ -11,7 +11,7 @@ import (
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
) )
func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc { func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
return func(gCtx *gin.Context) { return func(gCtx *gin.Context) {
var req ConfigureBrowserExtensionRequest var req ConfigureBrowserExtensionRequest
if err := gCtx.BindJSON(&req); err != nil { if err := gCtx.BindJSON(&req); err != nil {
@ -33,10 +33,9 @@ func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
} }
} }
func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc { func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
return func(gCtx *gin.Context) { return func(gCtx *gin.Context) {
// TODO: consider moving auth to middleware. token, err := tokenFromWSProtocol(gCtx.Request)
token, err := tokenFromRequest(gCtx)
if err != nil { if err != nil {
logging.Errorf("Failed to get token from request: %v", err) logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden) gCtx.Status(http.StatusForbidden)
@ -49,14 +48,14 @@ func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
gCtx.Status(http.StatusInternalServerError) gCtx.Status(http.StatusInternalServerError)
return return
} }
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID) pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
} }
} }
func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) { return func(gCtx *gin.Context) {
// TODO: consider moving auth to middleware. token, err := tokenFromWSProtocol(gCtx.Request)
token, err := tokenFromRequest(gCtx)
if err != nil { if err != nil {
logging.Errorf("Failed to get token from request: %v", err) logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden) gCtx.Status(http.StatusForbidden)
@ -84,7 +83,6 @@ func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.Hand
func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc { func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
return func(gCtx *gin.Context) { return func(gCtx *gin.Context) {
// TODO: consider moving auth to middleware.
token, err := tokenFromRequest(gCtx) token, err := tokenFromRequest(gCtx)
if err != nil { if err != nil {
logging.Errorf("Failed to get token from request: %v", err) logging.Errorf("Failed to get token from request: %v", err)
@ -116,10 +114,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
} }
} }
func MobileProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc { func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) { return func(gCtx *gin.Context) {
// TODO: consider moving auth to middleware. token, err := tokenFromWSProtocol(gCtx.Request)
token, err := tokenFromRequest(gCtx)
if err != nil { if err != nil {
logging.Errorf("Failed to get token from request: %v", err) logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden) gCtx.Status(http.StatusForbidden)

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"os"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -63,6 +62,11 @@ type WaitForConnectionResponse struct {
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) { func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
log := logging.WithField("extension_id", extID) log := logging.WithField("extension_id", extID)
upgrader, err := wsUpgraderForProtocol(r)
if err != nil {
log.Error(err)
return
}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Errorf("Failed to upgrade on ServePairingWS: %v", err) log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
@ -146,17 +150,3 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest,
PairedAt: time.Now().UTC(), PairedAt: time.Now().UTC(),
}) })
} }
var upgrader = websocket.Upgrader{
ReadBufferSize: 4 * 1024,
WriteBufferSize: 4 * 1024,
CheckOrigin: func(r *http.Request) bool {
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
if allowedOrigin != "" {
return r.Header.Get("Origin") == allowedOrigin
}
return true
},
}

View File

@ -197,6 +197,11 @@ func (c *client) writePump() {
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) { func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID) log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
upgrader, err := wsUpgraderForProtocol(r)
if err != nil {
log.Error(err)
return
}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err) log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
@ -228,6 +233,11 @@ func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Req
} }
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) { func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
upgrader, err := wsUpgraderForProtocol(r)
if err != nil {
logging.Error(err)
return
}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err) logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)

View File

@ -0,0 +1,100 @@
package pairing
import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/textproto"
"os"
"slices"
"strings"
"unicode/utf8"
"github.com/gorilla/websocket"
)
const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io."
var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol")
// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
// It is used because websocket API in browser does not allow to pass headers.
// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0
//
// Client provides a bearer token as a subprotocol in the format
// "base64url.bearer.authorization.2pass.io.<base64url-without-padding(bearer-token)>".
// This function also modified request header by removing authorization header.
// Server according to spec must return at least 1 protocol to client, so this fn checks
// if at least 1 protocol is sent, beside authorization header.
func tokenFromWSProtocol(req *http.Request) (string, error) {
token := ""
sawTokenProtocol := false
filteredProtocols := []string{}
for _, protocolHeader := range req.Header[protocolHeader] {
for _, protocol := range strings.Split(protocolHeader, ",") {
protocol = strings.TrimSpace(protocol)
if !strings.HasPrefix(protocol, bearerProtocolPrefix) {
filteredProtocols = append(filteredProtocols, protocol)
continue
}
if sawTokenProtocol {
return "", errors.New("multiple base64.bearer.authorization tokens specified")
}
sawTokenProtocol = true
encodedToken := strings.TrimPrefix(protocol, bearerProtocolPrefix)
decodedToken, err := base64.RawURLEncoding.DecodeString(encodedToken)
if err != nil {
return "", errors.New("invalid base64.bearer.authorization token encoding")
}
if !utf8.Valid(decodedToken) {
return "", errors.New("invalid base64.bearer.authorization token")
}
token = string(decodedToken)
}
}
if len(token) == 0 {
return "", errors.New("empty token")
}
// Must pass at least one other subprotocol so that we can remove the one containing the bearer token,
// and there is at least one to echo back to the client
if len(filteredProtocols) == 0 {
return "", errors.New("missing additional subprotocol")
}
// https://tools.ietf.org/html/rfc6455#section-11.3.4 indicates the Sec-WebSocket-Protocol header may appear multiple times
// in a request, and is logically the same as a single Sec-WebSocket-Protocol header field that contains all values
req.Header.Set(protocolHeader, strings.Join(filteredProtocols, ","))
return token, nil
}
// supportedProtocol2pass is Protocol that will be sent back on upgrade mechanism.
const supportedProtocol2pass = "2pass.io"
var upgrader2pass = websocket.Upgrader{
ReadBufferSize: 4 * 1024,
WriteBufferSize: 4 * 1024,
CheckOrigin: func(r *http.Request) bool {
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
if allowedOrigin != "" {
return r.Header.Get("Origin") == allowedOrigin
}
return true
},
Subprotocols: []string{supportedProtocol2pass},
}
func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) {
protocols := strings.Split(req.Header.Get(protocolHeader), ",")
if slices.Contains(protocols, supportedProtocol2pass) {
return upgrader2pass, nil
}
return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols)
}

View File

@ -0,0 +1,63 @@
package pairing
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func Test_tokenFromWSProtocol(t *testing.T) {
tests := []struct {
name string
protocolHeader string
assertFn func(t *testing.T, token string, err error)
}{
{
name: "valid token with additional subprotocol",
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,ws",
assertFn: func(t *testing.T, token string, err error) {
require.NoError(t, err)
require.Equal(t, "test1", token)
},
},
{
name: "missing header",
assertFn: func(t *testing.T, token string, err error) {
require.ErrorContains(t, err, "empty token")
},
},
{
name: "invalid encoding",
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdA==",
assertFn: func(t *testing.T, token string, err error) {
require.ErrorContains(t, err, "invalid base64.bearer.authorization token encoding")
},
},
{
name: "missing other protocol",
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE",
assertFn: func(t *testing.T, token string, err error) {
require.ErrorContains(t, err, "missing additional subprotocol")
},
},
{
name: "double authorization header",
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,base64url.bearer.authorization.2pass.io.dGVzdDE",
assertFn: func(t *testing.T, token string, err error) {
require.ErrorContains(t, err, "multiple base64.bearer.authorization tokens specified")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost", nil)
require.NoError(t, err)
if protocolHeader != "" {
req.Header.Set(protocolHeader, tt.protocolHeader)
}
got, err := tokenFromWSProtocol(req)
tt.assertFn(t, got, err)
})
}
}

View File

@ -27,11 +27,11 @@ func NewServer(addr string) *Server {
context.Status(200) context.Status(200)
}) })
router.POST("/browser_extension/configure", pairing.BrowserExtensionConfigureHandler(pairingApp)) router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
router.GET("/browser_extension/wait_for_connection", pairing.BrowserExtensionWaitForConnHandler(pairingApp)) router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp))
router.GET("/browser_extension/proxy_to_mobile", pairing.BrowserExtensionProxyHandler(pairingApp, proxyApp)) router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp))
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp)) router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyHandler(pairingApp, proxyApp)) router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp))
return &Server{ return &Server{
router: router, router: router,