mirror of
https://github.com/twofas/2fas-server.git
synced 2024-12-04 16:20:13 +01:00
feat: use Sec-WebSocket-Protocol to pass auth on ws
This commit is contained in:
parent
17fb204680
commit
85f0cd1206
@ -11,7 +11,7 @@ import (
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
var req ConfigureBrowserExtensionRequest
|
||||
if err := gCtx.BindJSON(&req); err != nil {
|
||||
@ -33,10 +33,9 @@ func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
@ -49,14 +48,14 @@ func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
|
||||
}
|
||||
}
|
||||
|
||||
func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
@ -84,7 +83,6 @@ func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.Hand
|
||||
|
||||
func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
@ -116,10 +114,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func MobileProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@ -63,6 +62,11 @@ type WaitForConnectionResponse struct {
|
||||
|
||||
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
|
||||
log := logging.WithField("extension_id", extID)
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
|
||||
@ -146,17 +150,3 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest,
|
||||
PairedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4 * 1024,
|
||||
WriteBufferSize: 4 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
||||
|
||||
if allowedOrigin != "" {
|
||||
return r.Header.Get("Origin") == allowedOrigin
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
@ -197,6 +197,11 @@ func (c *client) writePump() {
|
||||
|
||||
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
|
||||
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
|
||||
@ -228,6 +233,11 @@ func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
if err != nil {
|
||||
logging.Error(err)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
|
||||
|
100
internal/pass/pairing/websocket.go
Normal file
100
internal/pass/pairing/websocket.go
Normal file
@ -0,0 +1,100 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io."
|
||||
|
||||
var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol")
|
||||
|
||||
// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
|
||||
// It is used because websocket API in browser does not allow to pass headers.
|
||||
// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0
|
||||
//
|
||||
// Client provides a bearer token as a subprotocol in the format
|
||||
// "base64url.bearer.authorization.2pass.io.<base64url-without-padding(bearer-token)>".
|
||||
// This function also modified request header by removing authorization header.
|
||||
// Server according to spec must return at least 1 protocol to client, so this fn checks
|
||||
// if at least 1 protocol is sent, beside authorization header.
|
||||
func tokenFromWSProtocol(req *http.Request) (string, error) {
|
||||
token := ""
|
||||
sawTokenProtocol := false
|
||||
filteredProtocols := []string{}
|
||||
for _, protocolHeader := range req.Header[protocolHeader] {
|
||||
for _, protocol := range strings.Split(protocolHeader, ",") {
|
||||
protocol = strings.TrimSpace(protocol)
|
||||
|
||||
if !strings.HasPrefix(protocol, bearerProtocolPrefix) {
|
||||
filteredProtocols = append(filteredProtocols, protocol)
|
||||
continue
|
||||
}
|
||||
|
||||
if sawTokenProtocol {
|
||||
return "", errors.New("multiple base64.bearer.authorization tokens specified")
|
||||
}
|
||||
sawTokenProtocol = true
|
||||
|
||||
encodedToken := strings.TrimPrefix(protocol, bearerProtocolPrefix)
|
||||
decodedToken, err := base64.RawURLEncoding.DecodeString(encodedToken)
|
||||
if err != nil {
|
||||
return "", errors.New("invalid base64.bearer.authorization token encoding")
|
||||
}
|
||||
if !utf8.Valid(decodedToken) {
|
||||
return "", errors.New("invalid base64.bearer.authorization token")
|
||||
}
|
||||
token = string(decodedToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
return "", errors.New("empty token")
|
||||
}
|
||||
|
||||
// Must pass at least one other subprotocol so that we can remove the one containing the bearer token,
|
||||
// and there is at least one to echo back to the client
|
||||
if len(filteredProtocols) == 0 {
|
||||
return "", errors.New("missing additional subprotocol")
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc6455#section-11.3.4 indicates the Sec-WebSocket-Protocol header may appear multiple times
|
||||
// in a request, and is logically the same as a single Sec-WebSocket-Protocol header field that contains all values
|
||||
req.Header.Set(protocolHeader, strings.Join(filteredProtocols, ","))
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// supportedProtocol2pass is Protocol that will be sent back on upgrade mechanism.
|
||||
const supportedProtocol2pass = "2pass.io"
|
||||
|
||||
var upgrader2pass = websocket.Upgrader{
|
||||
ReadBufferSize: 4 * 1024,
|
||||
WriteBufferSize: 4 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
||||
|
||||
if allowedOrigin != "" {
|
||||
return r.Header.Get("Origin") == allowedOrigin
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
Subprotocols: []string{supportedProtocol2pass},
|
||||
}
|
||||
|
||||
func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) {
|
||||
protocols := strings.Split(req.Header.Get(protocolHeader), ",")
|
||||
if slices.Contains(protocols, supportedProtocol2pass) {
|
||||
return upgrader2pass, nil
|
||||
}
|
||||
return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols)
|
||||
}
|
63
internal/pass/pairing/websocket_test.go
Normal file
63
internal/pass/pairing/websocket_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_tokenFromWSProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocolHeader string
|
||||
assertFn func(t *testing.T, token string, err error)
|
||||
}{
|
||||
{
|
||||
name: "valid token with additional subprotocol",
|
||||
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,ws",
|
||||
assertFn: func(t *testing.T, token string, err error) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test1", token)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing header",
|
||||
assertFn: func(t *testing.T, token string, err error) {
|
||||
require.ErrorContains(t, err, "empty token")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid encoding",
|
||||
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdA==",
|
||||
assertFn: func(t *testing.T, token string, err error) {
|
||||
require.ErrorContains(t, err, "invalid base64.bearer.authorization token encoding")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing other protocol",
|
||||
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE",
|
||||
assertFn: func(t *testing.T, token string, err error) {
|
||||
require.ErrorContains(t, err, "missing additional subprotocol")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "double authorization header",
|
||||
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,base64url.bearer.authorization.2pass.io.dGVzdDE",
|
||||
assertFn: func(t *testing.T, token string, err error) {
|
||||
require.ErrorContains(t, err, "multiple base64.bearer.authorization tokens specified")
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "http://localhost", nil)
|
||||
require.NoError(t, err)
|
||||
if protocolHeader != "" {
|
||||
req.Header.Set(protocolHeader, tt.protocolHeader)
|
||||
}
|
||||
got, err := tokenFromWSProtocol(req)
|
||||
tt.assertFn(t, got, err)
|
||||
})
|
||||
}
|
||||
}
|
@ -27,11 +27,11 @@ func NewServer(addr string) *Server {
|
||||
context.Status(200)
|
||||
})
|
||||
|
||||
router.POST("/browser_extension/configure", pairing.BrowserExtensionConfigureHandler(pairingApp))
|
||||
router.GET("/browser_extension/wait_for_connection", pairing.BrowserExtensionWaitForConnHandler(pairingApp))
|
||||
router.GET("/browser_extension/proxy_to_mobile", pairing.BrowserExtensionProxyHandler(pairingApp, proxyApp))
|
||||
router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
|
||||
router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp))
|
||||
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp))
|
||||
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
|
||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyHandler(pairingApp, proxyApp))
|
||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp))
|
||||
|
||||
return &Server{
|
||||
router: router,
|
||||
|
Loading…
Reference in New Issue
Block a user