mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-08 15:29:36 +01:00
feat: use Sec-WebSocket-Protocol to pass auth on ws (#26)
This commit is contained in:
parent
17fb204680
commit
417d7d25b2
@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/twofas/2fas-server/internal/common/logging"
|
"github.com/twofas/2fas-server/internal/common/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||||
return func(gCtx *gin.Context) {
|
return func(gCtx *gin.Context) {
|
||||||
var req ConfigureBrowserExtensionRequest
|
var req ConfigureBrowserExtensionRequest
|
||||||
if err := gCtx.BindJSON(&req); err != nil {
|
if err := gCtx.BindJSON(&req); err != nil {
|
||||||
@ -33,10 +33,9 @@ func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
|
func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||||
return func(gCtx *gin.Context) {
|
return func(gCtx *gin.Context) {
|
||||||
// TODO: consider moving auth to middleware.
|
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||||
token, err := tokenFromRequest(gCtx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to get token from request: %v", err)
|
logging.Errorf("Failed to get token from request: %v", err)
|
||||||
gCtx.Status(http.StatusForbidden)
|
gCtx.Status(http.StatusForbidden)
|
||||||
@ -49,14 +48,14 @@ func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
|
|||||||
gCtx.Status(http.StatusInternalServerError)
|
gCtx.Status(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
|
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||||
return func(gCtx *gin.Context) {
|
return func(gCtx *gin.Context) {
|
||||||
// TODO: consider moving auth to middleware.
|
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||||
token, err := tokenFromRequest(gCtx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to get token from request: %v", err)
|
logging.Errorf("Failed to get token from request: %v", err)
|
||||||
gCtx.Status(http.StatusForbidden)
|
gCtx.Status(http.StatusForbidden)
|
||||||
@ -84,7 +83,6 @@ func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.Hand
|
|||||||
|
|
||||||
func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||||
return func(gCtx *gin.Context) {
|
return func(gCtx *gin.Context) {
|
||||||
// TODO: consider moving auth to middleware.
|
|
||||||
token, err := tokenFromRequest(gCtx)
|
token, err := tokenFromRequest(gCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to get token from request: %v", err)
|
logging.Errorf("Failed to get token from request: %v", err)
|
||||||
@ -116,10 +114,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MobileProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||||
return func(gCtx *gin.Context) {
|
return func(gCtx *gin.Context) {
|
||||||
// TODO: consider moving auth to middleware.
|
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||||
token, err := tokenFromRequest(gCtx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to get token from request: %v", err)
|
logging.Errorf("Failed to get token from request: %v", err)
|
||||||
gCtx.Status(http.StatusForbidden)
|
gCtx.Status(http.StatusForbidden)
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -63,6 +62,11 @@ type WaitForConnectionResponse struct {
|
|||||||
|
|
||||||
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
|
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
|
||||||
log := logging.WithField("extension_id", extID)
|
log := logging.WithField("extension_id", extID)
|
||||||
|
upgrader, err := wsUpgraderForProtocol(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
|
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
|
||||||
@ -146,17 +150,3 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest,
|
|||||||
PairedAt: time.Now().UTC(),
|
PairedAt: time.Now().UTC(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 4 * 1024,
|
|
||||||
WriteBufferSize: 4 * 1024,
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
|
||||||
|
|
||||||
if allowedOrigin != "" {
|
|
||||||
return r.Header.Get("Origin") == allowedOrigin
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
@ -197,6 +197,11 @@ func (c *client) writePump() {
|
|||||||
|
|
||||||
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
|
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
|
||||||
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
|
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
|
||||||
|
upgrader, err := wsUpgraderForProtocol(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
|
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
|
||||||
@ -228,6 +233,11 @@ func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
|
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
|
||||||
|
upgrader, err := wsUpgraderForProtocol(r)
|
||||||
|
if err != nil {
|
||||||
|
logging.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
|
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
|
||||||
|
100
internal/pass/pairing/websocket.go
Normal file
100
internal/pass/pairing/websocket.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package pairing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io."
|
||||||
|
|
||||||
|
var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol")
|
||||||
|
|
||||||
|
// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
|
||||||
|
// It is used because websocket API in browser does not allow to pass headers.
|
||||||
|
// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0
|
||||||
|
//
|
||||||
|
// Client provides a bearer token as a subprotocol in the format
|
||||||
|
// "base64url.bearer.authorization.2pass.io.<base64url-without-padding(bearer-token)>".
|
||||||
|
// This function also modified request header by removing authorization header.
|
||||||
|
// Server according to spec must return at least 1 protocol to client, so this fn checks
|
||||||
|
// if at least 1 protocol is sent, beside authorization header.
|
||||||
|
func tokenFromWSProtocol(req *http.Request) (string, error) {
|
||||||
|
token := ""
|
||||||
|
sawTokenProtocol := false
|
||||||
|
filteredProtocols := []string{}
|
||||||
|
for _, protocolHeader := range req.Header[protocolHeader] {
|
||||||
|
for _, protocol := range strings.Split(protocolHeader, ",") {
|
||||||
|
protocol = strings.TrimSpace(protocol)
|
||||||
|
|
||||||
|
if !strings.HasPrefix(protocol, bearerProtocolPrefix) {
|
||||||
|
filteredProtocols = append(filteredProtocols, protocol)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if sawTokenProtocol {
|
||||||
|
return "", errors.New("multiple base64.bearer.authorization tokens specified")
|
||||||
|
}
|
||||||
|
sawTokenProtocol = true
|
||||||
|
|
||||||
|
encodedToken := strings.TrimPrefix(protocol, bearerProtocolPrefix)
|
||||||
|
decodedToken, err := base64.RawURLEncoding.DecodeString(encodedToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("invalid base64.bearer.authorization token encoding")
|
||||||
|
}
|
||||||
|
if !utf8.Valid(decodedToken) {
|
||||||
|
return "", errors.New("invalid base64.bearer.authorization token")
|
||||||
|
}
|
||||||
|
token = string(decodedToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(token) == 0 {
|
||||||
|
return "", errors.New("empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must pass at least one other subprotocol so that we can remove the one containing the bearer token,
|
||||||
|
// and there is at least one to echo back to the client
|
||||||
|
if len(filteredProtocols) == 0 {
|
||||||
|
return "", errors.New("missing additional subprotocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-11.3.4 indicates the Sec-WebSocket-Protocol header may appear multiple times
|
||||||
|
// in a request, and is logically the same as a single Sec-WebSocket-Protocol header field that contains all values
|
||||||
|
req.Header.Set(protocolHeader, strings.Join(filteredProtocols, ","))
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportedProtocol2pass is Protocol that will be sent back on upgrade mechanism.
|
||||||
|
const supportedProtocol2pass = "2pass.io"
|
||||||
|
|
||||||
|
var upgrader2pass = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 4 * 1024,
|
||||||
|
WriteBufferSize: 4 * 1024,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
||||||
|
|
||||||
|
if allowedOrigin != "" {
|
||||||
|
return r.Header.Get("Origin") == allowedOrigin
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
Subprotocols: []string{supportedProtocol2pass},
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) {
|
||||||
|
protocols := strings.Split(req.Header.Get(protocolHeader), ",")
|
||||||
|
if slices.Contains(protocols, supportedProtocol2pass) {
|
||||||
|
return upgrader2pass, nil
|
||||||
|
}
|
||||||
|
return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols)
|
||||||
|
}
|
63
internal/pass/pairing/websocket_test.go
Normal file
63
internal/pass/pairing/websocket_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package pairing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_tokenFromWSProtocol(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocolHeader string
|
||||||
|
assertFn func(t *testing.T, token string, err error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token with additional subprotocol",
|
||||||
|
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,ws",
|
||||||
|
assertFn: func(t *testing.T, token string, err error) {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "test1", token)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing header",
|
||||||
|
assertFn: func(t *testing.T, token string, err error) {
|
||||||
|
require.ErrorContains(t, err, "empty token")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid encoding",
|
||||||
|
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdA==",
|
||||||
|
assertFn: func(t *testing.T, token string, err error) {
|
||||||
|
require.ErrorContains(t, err, "invalid base64.bearer.authorization token encoding")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing other protocol",
|
||||||
|
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE",
|
||||||
|
assertFn: func(t *testing.T, token string, err error) {
|
||||||
|
require.ErrorContains(t, err, "missing additional subprotocol")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "double authorization header",
|
||||||
|
protocolHeader: "base64url.bearer.authorization.2pass.io.dGVzdDE,base64url.bearer.authorization.2pass.io.dGVzdDE",
|
||||||
|
assertFn: func(t *testing.T, token string, err error) {
|
||||||
|
require.ErrorContains(t, err, "multiple base64.bearer.authorization tokens specified")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest("GET", "http://localhost", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if protocolHeader != "" {
|
||||||
|
req.Header.Set(protocolHeader, tt.protocolHeader)
|
||||||
|
}
|
||||||
|
got, err := tokenFromWSProtocol(req)
|
||||||
|
tt.assertFn(t, got, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -27,11 +27,11 @@ func NewServer(addr string) *Server {
|
|||||||
context.Status(200)
|
context.Status(200)
|
||||||
})
|
})
|
||||||
|
|
||||||
router.POST("/browser_extension/configure", pairing.BrowserExtensionConfigureHandler(pairingApp))
|
router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
|
||||||
router.GET("/browser_extension/wait_for_connection", pairing.BrowserExtensionWaitForConnHandler(pairingApp))
|
router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp))
|
||||||
router.GET("/browser_extension/proxy_to_mobile", pairing.BrowserExtensionProxyHandler(pairingApp, proxyApp))
|
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp))
|
||||||
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
|
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
|
||||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyHandler(pairingApp, proxyApp))
|
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp))
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
router: router,
|
router: router,
|
||||||
|
Loading…
Reference in New Issue
Block a user