mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-05 22:15:56 +01:00
parent
4c16a1eacc
commit
34d87a852a
@ -1,8 +1,12 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
type PassConfig struct {
|
||||
Addr string `envconfig:"PASS_ADDR" default:":8082"`
|
||||
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
|
||||
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
|
||||
AWSRegion string `envconfig:"AWS_REGION" default:"us-east-2"`
|
||||
Addr string `envconfig:"PASS_ADDR" default:":8082"`
|
||||
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
|
||||
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
|
||||
AWSRegion string `envconfig:"AWS_REGION" default:"us-east-2"`
|
||||
FakeMobilePush bool `envconfig:"FAKE_MOBILE_PUSH" default:"false"`
|
||||
PairingRequestTokenValidityDuration time.Duration `envconfig:"PAIRING_REQUEST_TOKEN_VALIDITY_DURATION" default:"8765h"` // 1 year
|
||||
}
|
||||
|
149
internal/pass/connection/proxy.go
Normal file
149
internal/pass/connection/proxy.go
Normal file
@ -0,0 +1,149 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// Maximum message size allowed from peer.
|
||||
maxMessageSize = 4 * 1048
|
||||
)
|
||||
|
||||
var (
|
||||
newline = []byte{'\n'}
|
||||
space = []byte{' '}
|
||||
|
||||
acceptedCloseStatus = []int{
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived,
|
||||
websocket.CloseAbnormalClosure,
|
||||
}
|
||||
)
|
||||
|
||||
// proxy is a responsible for reading from read chan and sending it over wsConn
|
||||
// and reading fom wsChan and sending it over send chan
|
||||
type proxy struct {
|
||||
send chan []byte
|
||||
read chan []byte
|
||||
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func startProxy(wsConn *websocket.Conn, send, read chan []byte) {
|
||||
proxy := &proxy{
|
||||
send: send,
|
||||
read: read,
|
||||
conn: wsConn,
|
||||
}
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
proxy.writePump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
proxy.readPump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
|
||||
proxy.conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// readPump pumps messages from the websocket proxy to send.
|
||||
//
|
||||
// The application runs readPump in a per-proxy goroutine. The application
|
||||
// ensures that there is at most one reader on a proxy by executing all
|
||||
// reads from this goroutine.
|
||||
func (p *proxy) readPump() {
|
||||
defer func() {
|
||||
p.conn.Close()
|
||||
close(p.send)
|
||||
}()
|
||||
|
||||
p.conn.SetReadLimit(maxMessageSize)
|
||||
p.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
p.conn.SetPongHandler(func(string) error {
|
||||
p.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := p.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Error("Websocket proxy closed unexpected")
|
||||
} else {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Info("Connection closed")
|
||||
}
|
||||
break
|
||||
}
|
||||
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
|
||||
p.send <- message
|
||||
}
|
||||
}
|
||||
|
||||
// writePump pumps messages from the read chan to the websocket proxy.
|
||||
//
|
||||
// A goroutine running writePump is started for each proxy. The
|
||||
// application ensures that there is at most one writer to a proxy by
|
||||
// executing all writes from this goroutine.
|
||||
func (p *proxy) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
p.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-p.read:
|
||||
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// The hub closed the channel.
|
||||
p.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := p.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
50
internal/pass/connection/proxy_pool.go
Normal file
50
internal/pass/connection/proxy_pool.go
Normal file
@ -0,0 +1,50 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type proxyPool struct {
|
||||
mu sync.Mutex
|
||||
proxies map[string]*proxyPair
|
||||
}
|
||||
|
||||
// registerMobileConn register proxyPair if not existing in pool and returns it.
|
||||
func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair {
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
v, ok := pp.proxies[id]
|
||||
if !ok {
|
||||
v = initProxyPair()
|
||||
}
|
||||
pp.proxies[id] = v
|
||||
return v
|
||||
}
|
||||
|
||||
func (pp *proxyPool) deleteExpiresPairs() {
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
|
||||
for key, pair := range pp.proxies {
|
||||
if time.Now().After(pair.expiresAt) {
|
||||
delete(pp.proxies, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type proxyPair struct {
|
||||
toMobileDataCh chan []byte
|
||||
toExtensionDataCh chan []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
|
||||
func initProxyPair() *proxyPair {
|
||||
const proxyTimeout = 3 * time.Minute
|
||||
return &proxyPair{
|
||||
toMobileDataCh: make(chan []byte),
|
||||
toExtensionDataCh: make(chan []byte),
|
||||
expiresAt: time.Now().Add(proxyTimeout),
|
||||
}
|
||||
}
|
58
internal/pass/connection/proxy_server.go
Normal file
58
internal/pass/connection/proxy_server.go
Normal file
@ -0,0 +1,58 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
// ProxyServer manages proxy connections between Browser Extension and Mobile.
|
||||
type ProxyServer struct {
|
||||
proxyPool *proxyPool
|
||||
idLabel string
|
||||
}
|
||||
|
||||
func NewProxyServer(idLabel string) *ProxyServer {
|
||||
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
<-ticker.C
|
||||
proxyPool.deleteExpiresPairs()
|
||||
}
|
||||
}()
|
||||
return &ProxyServer{
|
||||
proxyPool: proxyPool,
|
||||
idLabel: idLabel,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, id string) error {
|
||||
log := logging.WithField(p.idLabel, id)
|
||||
conn, err := Upgrade(w, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade proxy: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Starting ServeExtensionProxyToMobileWS")
|
||||
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
|
||||
startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, id string) error {
|
||||
conn, err := Upgrade(w, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade proxy: %w", err)
|
||||
}
|
||||
|
||||
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
|
||||
|
||||
startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh)
|
||||
|
||||
return nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package pairing
|
||||
package connection
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
@ -18,16 +18,16 @@ const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io."
|
||||
|
||||
var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol")
|
||||
|
||||
// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
|
||||
// TokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
|
||||
// It is used because websocket API in browser does not allow to pass headers.
|
||||
// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0
|
||||
//
|
||||
// Client provides a bearer token as a subprotocol in the format
|
||||
// "base64url.bearer.authorization.2pass.io.<base64url-without-padding(bearer-token)>".
|
||||
// This function also modified request header by removing authorization header.
|
||||
// Server according to spec must return at least 1 protocol to client, so this fn checks
|
||||
// Server according to spec must return at least 1 protocol to proxy, so this fn checks
|
||||
// if at least 1 protocol is sent, beside authorization header.
|
||||
func tokenFromWSProtocol(req *http.Request) (string, error) {
|
||||
func TokenFromWSProtocol(req *http.Request) (string, error) {
|
||||
token := ""
|
||||
sawTokenProtocol := false
|
||||
filteredProtocols := []string{}
|
||||
@ -62,7 +62,7 @@ func tokenFromWSProtocol(req *http.Request) (string, error) {
|
||||
}
|
||||
|
||||
// Must pass at least one other subprotocol so that we can remove the one containing the bearer token,
|
||||
// and there is at least one to echo back to the client
|
||||
// and there is at least one to echo back to the proxy
|
||||
if len(filteredProtocols) == 0 {
|
||||
return "", errors.New("missing additional subprotocol")
|
||||
}
|
||||
@ -91,10 +91,14 @@ var upgrader2pass = websocket.Upgrader{
|
||||
Subprotocols: []string{supportedProtocol2pass},
|
||||
}
|
||||
|
||||
func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) {
|
||||
func Upgrade(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) {
|
||||
protocols := strings.Split(req.Header.Get(protocolHeader), ",")
|
||||
if slices.Contains(protocols, supportedProtocol2pass) {
|
||||
return upgrader2pass, nil
|
||||
if !slices.Contains(protocols, supportedProtocol2pass) {
|
||||
return nil, fmt.Errorf("upgrader not available for protocols: %v", protocols)
|
||||
}
|
||||
return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols)
|
||||
conn, err := upgrader2pass.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upgrade proxy: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package pairing
|
||||
package connection
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_tokenFromWSProtocol(t *testing.T) {
|
||||
func Test_TokenFromWSProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocolHeader string
|
||||
@ -56,7 +56,7 @@ func Test_tokenFromWSProtocol(t *testing.T) {
|
||||
if protocolHeader != "" {
|
||||
req.Header.Set(protocolHeader, tt.protocolHeader)
|
||||
}
|
||||
got, err := tokenFromWSProtocol(req)
|
||||
got, err := TokenFromWSProtocol(req)
|
||||
tt.assertFn(t, got, err)
|
||||
})
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
)
|
||||
|
||||
func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
@ -35,27 +36,31 @@ func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
|
||||
func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
gCtx.Status(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
extensionID, err := pairingApp.VerifyPairingToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify pairing token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
gCtx.Status(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
|
||||
if err := pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID); err != nil {
|
||||
logging.Errorf("Failed serve ws: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *connection.ProxyServer) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
@ -77,7 +82,11 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu
|
||||
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
|
||||
return
|
||||
}
|
||||
proxyApp.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, extensionID, pairingInfo.Device.DeviceID)
|
||||
if err := proxyApp.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID); err != nil {
|
||||
logging.Errorf("Failed to serve ws: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,9 +125,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
func MobileProxyWSHandler(pairingApp *Pairing, proxy *connection.ProxyServer) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := tokenFromWSProtocol(gCtx.Request)
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
@ -141,7 +150,11 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc
|
||||
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
|
||||
return
|
||||
}
|
||||
proxyApp.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID)
|
||||
if err := proxy.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID); err != nil {
|
||||
log.Errorf("Failed to serve ws: %w", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,13 +8,16 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
)
|
||||
|
||||
type Pairing struct {
|
||||
store store
|
||||
signSvc *sign.Service
|
||||
store store
|
||||
signSvc *sign.Service
|
||||
pairingRequestTokenValidityDuration time.Duration
|
||||
}
|
||||
|
||||
type store interface {
|
||||
@ -24,10 +27,11 @@ type store interface {
|
||||
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
|
||||
}
|
||||
|
||||
func NewPairingApp(signService *sign.Service) *Pairing {
|
||||
func NewApp(signService *sign.Service, pairingRequestTokenValidityDuration time.Duration) *Pairing {
|
||||
return &Pairing{
|
||||
store: NewMemoryStore(),
|
||||
signSvc: signService,
|
||||
store: NewMemoryStore(),
|
||||
signSvc: signService,
|
||||
pairingRequestTokenValidityDuration: pairingRequestTokenValidityDuration,
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,31 +81,25 @@ type ExtensionWaitForConnectionInput struct {
|
||||
|
||||
type WaitForConnectionResponse struct {
|
||||
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||
BrowserExtensionSyncToken string `json:"browser_extension_sync_token"`
|
||||
Status string `json:"status"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
|
||||
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) error {
|
||||
log := logging.WithField("extension_id", extID)
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
conn, err := connection.Upgrade(w, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
return fmt.Errorf("failed to upgrade connection: %w", err)
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
log.Info("Starting pairing WS")
|
||||
|
||||
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
|
||||
if pairing, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, pairing, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
@ -116,43 +114,56 @@ func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID s
|
||||
select {
|
||||
case <-maxWaitC:
|
||||
log.Info("Closing paring ws after timeout")
|
||||
return
|
||||
return nil
|
||||
case <-connectedCheckTicker.C:
|
||||
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
|
||||
if pairing, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, pairing, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
log.WithField("device_id", deviceID).Infof("Paring ws finished")
|
||||
return
|
||||
log.WithField("device_id", pairing.Device.DeviceID).Infof("Paring ws finished")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (string, bool) {
|
||||
func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (PairingInfo, bool) {
|
||||
pairingInfo, err := p.store.GetPairingInfo(ctx, extID)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get pairing info")
|
||||
return "", false
|
||||
return PairingInfo{}, false
|
||||
}
|
||||
return pairingInfo.Device.DeviceID, pairingInfo.IsPaired()
|
||||
return pairingInfo, pairingInfo.IsPaired()
|
||||
}
|
||||
|
||||
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
|
||||
func (p *Pairing) sendTokenAndCloseConn(extID string, pairingInfo PairingInfo, conn *websocket.Conn) error {
|
||||
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
|
||||
ConnectionID: extID,
|
||||
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to generate ext proxy token: %v", err)
|
||||
return fmt.Errorf("failed to generate ext proxy token: %v", err)
|
||||
}
|
||||
var syncToken string
|
||||
if pairingInfo.Device.FCMToken != "" {
|
||||
syncToken, err = p.signSvc.SignAndEncode(sign.Message{
|
||||
ConnectionID: pairingInfo.Device.FCMToken,
|
||||
ExpiresAt: time.Now().Add(p.pairingRequestTokenValidityDuration),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionSyncRequest,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate proxy sync request token: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(WaitForConnectionResponse{
|
||||
BrowserExtensionProxyToken: extProxyToken,
|
||||
BrowserExtensionSyncToken: syncToken,
|
||||
Status: "ok",
|
||||
DeviceID: deviceID,
|
||||
DeviceID: pairingInfo.Device.DeviceID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write to extension: %v", err)
|
||||
}
|
||||
@ -191,5 +202,6 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest,
|
||||
}); err != nil {
|
||||
return ConfirmPairingResponse{}, err
|
||||
}
|
||||
|
||||
return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil
|
||||
}
|
||||
|
@ -1,269 +0,0 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
proxyPool *proxyPool
|
||||
}
|
||||
|
||||
func NewProxy() *Proxy {
|
||||
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
<-ticker.C
|
||||
proxyPool.deleteExpiresPairs()
|
||||
}
|
||||
}()
|
||||
return &Proxy{
|
||||
proxyPool: proxyPool,
|
||||
}
|
||||
}
|
||||
|
||||
type proxyPool struct {
|
||||
mu sync.Mutex
|
||||
proxies map[string]*proxyPair
|
||||
}
|
||||
|
||||
// registerMobileConn register proxyPair if not existing in pool and returns it.
|
||||
func (pp *proxyPool) getOrCreateProxyPair(deviceID string) *proxyPair {
|
||||
// TODO: handle delete.
|
||||
// TODO: right now two connections to the same WS results in race for messages/ decide if we want multiple conn or not.
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
v, ok := pp.proxies[deviceID]
|
||||
if !ok {
|
||||
v = initProxyPair()
|
||||
}
|
||||
pp.proxies[deviceID] = v
|
||||
return v
|
||||
}
|
||||
|
||||
func (pp *proxyPool) deleteExpiresPairs() {
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
|
||||
for key, pair := range pp.proxies {
|
||||
if time.Now().After(pair.expiresAt) {
|
||||
delete(pp.proxies, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type proxyPair struct {
|
||||
toMobileDataCh chan []byte
|
||||
toExtensionDataCh chan []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
|
||||
func initProxyPair() *proxyPair {
|
||||
const proxyTimeout = 3 * time.Minute
|
||||
return &proxyPair{
|
||||
toMobileDataCh: make(chan []byte),
|
||||
toExtensionDataCh: make(chan []byte),
|
||||
expiresAt: time.Now().Add(proxyTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
newline = []byte{'\n'}
|
||||
space = []byte{' '}
|
||||
|
||||
acceptedCloseStatus = []int{
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived,
|
||||
websocket.CloseAbnormalClosure,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// Maximum message size allowed from peer.
|
||||
maxMessageSize = 4 * 1048
|
||||
)
|
||||
|
||||
// client is a responsible for reading from read chan and sending it over wsConn
|
||||
// and reading fom wsChan and sending it over send chan
|
||||
type client struct {
|
||||
send chan []byte
|
||||
read chan []byte
|
||||
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func newClient(wsConn *websocket.Conn, send, read chan []byte) *client {
|
||||
return &client{
|
||||
send: send,
|
||||
read: read,
|
||||
conn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
// readPump pumps messages from the websocket connection to send.
|
||||
//
|
||||
// The application runs readPump in a per-connection goroutine. The application
|
||||
// ensures that there is at most one reader on a connection by executing all
|
||||
// reads from this goroutine.
|
||||
func (c *client) readPump() {
|
||||
defer func() {
|
||||
c.conn.Close()
|
||||
close(c.send)
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Error("Websocket connection closed unexpected")
|
||||
} else {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Info("Connection closed")
|
||||
}
|
||||
break
|
||||
}
|
||||
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
|
||||
c.send <- message
|
||||
}
|
||||
}
|
||||
|
||||
// writePump pumps messages from the read chan to the websocket connection.
|
||||
//
|
||||
// A goroutine running writePump is started for each connection. The
|
||||
// application ensures that there is at most one writer to a connection by
|
||||
// executing all writes from this goroutine.
|
||||
func (c *client) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.read:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// The hub closed the channel.
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
|
||||
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Starting ServeExtensionProxyToMobileWS")
|
||||
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
|
||||
client := newClient(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh)
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.writePump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.readPump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
|
||||
client.conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
|
||||
upgrader, err := wsUpgraderForProtocol(r)
|
||||
if err != nil {
|
||||
logging.Error(err)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", deviceID)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
|
||||
|
||||
client := newClient(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh)
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.writePump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.readPump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
|
||||
client.conn.Close()
|
||||
})
|
||||
}
|
@ -7,12 +7,15 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
|
||||
"github.com/twofas/2fas-server/config"
|
||||
httphelpers "github.com/twofas/2fas-server/internal/common/http"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
"github.com/twofas/2fas-server/internal/pass/pairing"
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
"github.com/twofas/2fas-server/internal/pass/sync"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@ -48,8 +51,11 @@ func NewServer(cfg config.PassConfig) *Server {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
pairingApp := pairing.NewPairingApp(signSvc)
|
||||
proxyApp := pairing.NewProxy()
|
||||
pairingApp := pairing.NewApp(signSvc, cfg.PairingRequestTokenValidityDuration)
|
||||
proxyPairingApp := connection.NewProxyServer("device_id")
|
||||
|
||||
syncApp := sync.NewApp(signSvc, cfg.FakeMobilePush)
|
||||
proxySyncApp := connection.NewProxyServer("fcm_token")
|
||||
|
||||
router := gin.New()
|
||||
router.Use(recovery.RecoveryMiddleware())
|
||||
@ -62,11 +68,29 @@ func NewServer(cfg config.PassConfig) *Server {
|
||||
context.Status(200)
|
||||
})
|
||||
|
||||
router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
|
||||
// Deprecated paths start here.
|
||||
router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp))
|
||||
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp))
|
||||
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyPairingApp))
|
||||
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
|
||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp))
|
||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyPairingApp))
|
||||
// Deprecated paths end here.
|
||||
|
||||
router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
|
||||
|
||||
router.GET("/browser_extension/pairing/wait", pairing.ExtensionWaitForConnWSHandler(pairingApp))
|
||||
router.GET("/browser_extension/pairing/proxy", pairing.ExtensionProxyWSHandler(pairingApp, proxyPairingApp))
|
||||
router.POST("/mobile/pairing/confirm", pairing.MobileConfirmHandler(pairingApp))
|
||||
router.GET("/mobile/pairing/proxy", pairing.MobileProxyWSHandler(pairingApp, proxyPairingApp))
|
||||
|
||||
router.GET("/browser_extension/sync/request", sync.ExtensionRequestSync(syncApp))
|
||||
router.GET("/browser_extension/sync/proxy", sync.ExtensionProxyWSHandler(syncApp, proxySyncApp))
|
||||
router.POST("/mobile/sync/confirm", sync.MobileConfirmHandler(syncApp))
|
||||
router.GET("/mobile/sync/proxy", sync.MobileProxyWSHandler(syncApp, proxySyncApp))
|
||||
|
||||
if cfg.FakeMobilePush {
|
||||
logging.Info("Enabled '/mobile/sync/:fcm/token' endpoint. This should happen in test env only!")
|
||||
router.GET("/mobile/sync/:fcm/token", sync.MobileGenerateSyncToken(syncApp))
|
||||
}
|
||||
|
||||
return &Server{
|
||||
router: router,
|
||||
|
@ -56,8 +56,12 @@ func NewService(keyID string, client *kms.KMS) (*Service, error) {
|
||||
type ConnectionType string
|
||||
|
||||
const (
|
||||
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
|
||||
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
|
||||
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
|
||||
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
|
||||
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
|
||||
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
|
||||
ConnectionTypeBrowserExtensionSyncRequest ConnectionType = "be/sync/request"
|
||||
ConnectionTypeBrowserExtensionSync ConnectionType = "be/sync/proxy"
|
||||
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
|
||||
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
|
||||
ConnectionTypeMobileSyncConfirm ConnectionType = "mobile/sync/confirm"
|
||||
ConnectionTypeMobileSyncProxy ConnectionType = "mobile/sync/proxy"
|
||||
)
|
||||
|
44
internal/pass/sync/auth.go
Normal file
44
internal/pass/sync/auth.go
Normal file
@ -0,0 +1,44 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
)
|
||||
|
||||
// VerifyExtRequestSyncToken verifies sync request token and returns fcm_token.
|
||||
func (s *Syncing) VerifyExtRequestSyncToken(ctx context.Context, proxyToken string) (string, error) {
|
||||
fcmToken, err := s.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionSyncRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||
}
|
||||
return fcmToken, nil
|
||||
}
|
||||
|
||||
// VerifyExtSyncToken verifies sync token and returns fcm_token.
|
||||
func (s *Syncing) VerifyExtSyncToken(ctx context.Context, proxyToken string) (string, error) {
|
||||
fcmToken, err := s.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionSync)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||
}
|
||||
return fcmToken, nil
|
||||
}
|
||||
|
||||
// VerifyMobileSyncConfirmToken verifies mobile token and returns connection id.
|
||||
func (s *Syncing) VerifyMobileSyncConfirmToken(ctx context.Context, proxyToken string) (string, error) {
|
||||
extensionID, err := s.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileSyncConfirm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||
}
|
||||
return extensionID, nil
|
||||
}
|
||||
|
||||
// VerifyMobileSyncProxyToken verifies mobile token and returns connection id.
|
||||
func (s *Syncing) VerifyMobileSyncProxyToken(ctx context.Context, proxyToken string) (string, error) {
|
||||
extensionID, err := s.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileSyncProxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check token signature: %w", err)
|
||||
}
|
||||
return extensionID, nil
|
||||
}
|
142
internal/pass/sync/handlers.go
Normal file
142
internal/pass/sync/handlers.go
Normal file
@ -0,0 +1,142 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
)
|
||||
|
||||
func ExtensionRequestSync(syncingApp *Syncing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
fcmToken, err := syncingApp.VerifyExtRequestSyncToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify proxy token: %v", err)
|
||||
gCtx.String(http.StatusUnauthorized, "Invalid auth token")
|
||||
return
|
||||
}
|
||||
|
||||
if err := syncingApp.ServeSyncingRequestWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
|
||||
logging.Errorf("Failed to verify proxy token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExtensionProxyWSHandler(syncingApp *Syncing, proxy *connection.ProxyServer) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
fcmToken, err := syncingApp.VerifyExtSyncToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify proxy token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
ok := syncingApp.isSyncConfirmed(gCtx, fcmToken)
|
||||
if !ok {
|
||||
gCtx.String(http.StatusForbidden, "Syncing is not yet done")
|
||||
return
|
||||
}
|
||||
if err := proxy.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
|
||||
logging.Errorf("Failed to serve ws: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MobileConfirmHandler(syncApp *Syncing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
fcmToken, err := syncApp.VerifyMobileSyncConfirmToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify connection token: %v", err)
|
||||
gCtx.Status(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
resp, err := syncApp.confirmSync(gCtx, fcmToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, noSyncRequestErr) {
|
||||
gCtx.String(http.StatusBadRequest, "no sync request was created for this token")
|
||||
return
|
||||
}
|
||||
logging.Errorf("Failed to ConfirmPairing: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
gCtx.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func MobileProxyWSHandler(syncingApp *Syncing, proxy *connection.ProxyServer) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
token, err := connection.TokenFromWSProtocol(gCtx.Request)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
fcmToken, err := syncingApp.VerifyMobileSyncConfirmToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Invalid connection token: %v", err)
|
||||
gCtx.Status(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
ok := syncingApp.isSyncConfirmed(gCtx, fcmToken)
|
||||
if !ok {
|
||||
gCtx.String(http.StatusForbidden, "Syncing is not yet done")
|
||||
return
|
||||
}
|
||||
if err := proxy.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
|
||||
logging.Errorf("Failed to serve ws: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MobileGenerateSyncToken(syncingApp *Syncing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
fcm := gCtx.Param("fcm")
|
||||
if fcm == "" {
|
||||
gCtx.Status(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := syncingApp.sendMobileToken(fcm, gCtx.Writer); err != nil {
|
||||
logging.Errorf("Failed to send mobile token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tokenFromRequest(gCtx *gin.Context) (string, error) {
|
||||
tokenHeader := gCtx.GetHeader("Authorization")
|
||||
if tokenHeader == "" {
|
||||
return "", errors.New("missing Authorization header")
|
||||
}
|
||||
splitToken := strings.Split(tokenHeader, "Bearer ")
|
||||
if len(splitToken) != 2 {
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return "", errors.New("missing 'Bearer: value'")
|
||||
}
|
||||
return splitToken[1], nil
|
||||
}
|
61
internal/pass/sync/memorystore.go
Normal file
61
internal/pass/sync/memorystore.go
Normal file
@ -0,0 +1,61 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryStore keeps in memory pairing between extension and mobile.
|
||||
//
|
||||
// TODO: check ttlcache pkg, right now entries are not invalidated.
|
||||
type MemoryStore struct {
|
||||
mu sync.Mutex
|
||||
extensionsMap map[string]Item
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
FCMToken string
|
||||
Expires time.Time
|
||||
Confirmed bool
|
||||
}
|
||||
|
||||
func NewMemoryStore() *MemoryStore {
|
||||
return &MemoryStore{
|
||||
extensionsMap: make(map[string]Item),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) RequestSync(fcmToken string) {
|
||||
s.setItem(fcmToken, Item{FCMToken: fcmToken})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ConfirmSync(fcmToken string) bool {
|
||||
v, ok := s.getItem(fcmToken)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
v.Confirmed = true
|
||||
s.setItem(fcmToken, v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *MemoryStore) IsSyncCofirmed(fcmToken string) bool {
|
||||
v, ok := s.getItem(fcmToken)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return v.Confirmed
|
||||
}
|
||||
|
||||
func (s *MemoryStore) setItem(key string, item Item) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.extensionsMap[key] = item
|
||||
}
|
||||
|
||||
func (s *MemoryStore) getItem(key string) (Item, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, ok := s.extensionsMap[key]
|
||||
return v, ok
|
||||
}
|
180
internal/pass/sync/sync.go
Normal file
180
internal/pass/sync/sync.go
Normal file
@ -0,0 +1,180 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
)
|
||||
|
||||
type Syncing struct {
|
||||
store store
|
||||
signSvc *sign.Service
|
||||
}
|
||||
|
||||
type store interface {
|
||||
RequestSync(fmtToken string)
|
||||
ConfirmSync(fmtToken string) bool
|
||||
IsSyncCofirmed(fmtToken string) bool
|
||||
}
|
||||
|
||||
func NewApp(signService *sign.Service, fakeMobilePush bool) *Syncing {
|
||||
return &Syncing{
|
||||
store: NewMemoryStore(),
|
||||
signSvc: signService,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
syncTokenValidityDuration = 3 * time.Minute
|
||||
)
|
||||
|
||||
type ConfigureBrowserExtensionRequest struct {
|
||||
ExtensionID string `json:"extension_id"`
|
||||
}
|
||||
|
||||
type ConfigureBrowserExtensionResponse struct {
|
||||
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
|
||||
ConnectionToken string `json:"connection_token"`
|
||||
}
|
||||
|
||||
type ExtensionWaitForConnectionInput struct {
|
||||
ResponseWriter http.ResponseWriter
|
||||
HttpReq *http.Request
|
||||
}
|
||||
|
||||
type RequestSyncResponse struct {
|
||||
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type MobileSyncPayload struct {
|
||||
MobileSyncToken string `json:"mobile_sync_token"`
|
||||
}
|
||||
|
||||
func (s *Syncing) ServeSyncingRequestWS(w http.ResponseWriter, r *http.Request, fcmToken string) error {
|
||||
log := logging.WithField("fcm_token", fcmToken)
|
||||
conn, err := connection.Upgrade(w, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade connection: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
log.Infof("Starting sync request WS for %q", fcmToken)
|
||||
s.requestSync(r.Context(), fcmToken)
|
||||
|
||||
if syncDone := s.isSyncConfirmed(r.Context(), fcmToken); syncDone {
|
||||
if err := s.sendTokenAndCloseConn(fcmToken, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
}
|
||||
log.Infof("Paring ws finished")
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxWaitTime = 3 * time.Minute
|
||||
checkIfConnectedInterval = time.Second
|
||||
)
|
||||
maxWaitC := time.After(maxWaitTime)
|
||||
connectedCheckTicker := time.NewTicker(checkIfConnectedInterval)
|
||||
defer connectedCheckTicker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-maxWaitC:
|
||||
log.Info("Closing paring ws after timeout")
|
||||
return nil
|
||||
case <-connectedCheckTicker.C:
|
||||
if syncConfirmed := s.isSyncConfirmed(r.Context(), fcmToken); syncConfirmed {
|
||||
if err := s.sendTokenAndCloseConn(fcmToken, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
return nil
|
||||
}
|
||||
log.Infof("Paring ws finished")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Syncing) isSyncConfirmed(ctx context.Context, fcmToken string) bool {
|
||||
return s.store.IsSyncCofirmed(fcmToken)
|
||||
}
|
||||
|
||||
func (s *Syncing) requestSync(ctx context.Context, fcmToken string) {
|
||||
s.store.RequestSync(fcmToken)
|
||||
}
|
||||
|
||||
func (s *Syncing) sendTokenAndCloseConn(fcmToken string, conn *websocket.Conn) error {
|
||||
extProxyToken, err := s.signSvc.SignAndEncode(sign.Message{
|
||||
ConnectionID: fcmToken,
|
||||
ExpiresAt: time.Now().Add(syncTokenValidityDuration),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionSync,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate ext proxy token: %v", err)
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(RequestSyncResponse{
|
||||
BrowserExtensionProxyToken: extProxyToken,
|
||||
Status: "ok",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write to extension: %v", err)
|
||||
}
|
||||
return conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
}
|
||||
|
||||
func (s *Syncing) sendMobileToken(fcmToken string, resp http.ResponseWriter) error {
|
||||
extProxyToken, err := s.signSvc.SignAndEncode(sign.Message{
|
||||
ConnectionID: fcmToken,
|
||||
ExpiresAt: time.Now().Add(syncTokenValidityDuration),
|
||||
ConnectionType: sign.ConnectionTypeMobileSyncConfirm,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate ext proxy token: %v", err)
|
||||
}
|
||||
|
||||
bb, err := json.Marshal(struct {
|
||||
MobileSyncConfirmToken string `json:"mobile_sync_confirm_token"`
|
||||
}{
|
||||
MobileSyncConfirmToken: extProxyToken,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal the response: %v", err)
|
||||
}
|
||||
if _, err := resp.Write(bb); err != nil {
|
||||
return fmt.Errorf("failed to write the response: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ConfirmSyncResponse struct {
|
||||
ProxyToken string `json:"proxy_token"`
|
||||
}
|
||||
|
||||
var noSyncRequestErr = errors.New("sync request was not created")
|
||||
|
||||
func (s *Syncing) confirmSync(ctx context.Context, fcmToken string) (ConfirmSyncResponse, error) {
|
||||
logging.Infof("Starting sync confirm for %q", fcmToken)
|
||||
|
||||
mobileProxyToken, err := s.signSvc.SignAndEncode(sign.Message{
|
||||
ConnectionID: fcmToken,
|
||||
ExpiresAt: time.Now().Add(syncTokenValidityDuration),
|
||||
ConnectionType: sign.ConnectionTypeMobileSyncConfirm,
|
||||
})
|
||||
if err != nil {
|
||||
return ConfirmSyncResponse{}, fmt.Errorf("failed to generate ext proxy token: %v", err)
|
||||
}
|
||||
if ok := s.store.ConfirmSync(fcmToken); !ok {
|
||||
return ConfirmSyncResponse{}, noSyncRequestErr
|
||||
}
|
||||
|
||||
return ConfirmSyncResponse{ProxyToken: mobileProxyToken}, nil
|
||||
}
|
@ -3,7 +3,7 @@
|
||||
apt install --assume-yes jq
|
||||
|
||||
AWS_REGION=us-east-1
|
||||
KEY_ALIAS=pass_service
|
||||
KEY_ALIAS=pass_service_signing_key
|
||||
|
||||
response=$(awslocal kms create-key \
|
||||
--region $AWS_REGION \
|
||||
|
158
tests/pass/http.go
Normal file
158
tests/pass/http.go
Normal file
@ -0,0 +1,158 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/avast/retry-go/v4"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
httpClient = http.DefaultClient
|
||||
)
|
||||
|
||||
func getApiURL() string {
|
||||
apiURL := os.Getenv("API_URL")
|
||||
if apiURL != "" {
|
||||
return apiURL
|
||||
}
|
||||
return "http://" + getPassAddr()
|
||||
}
|
||||
|
||||
func getPassAddr() string {
|
||||
addr := os.Getenv("PASS_ADDR")
|
||||
if addr != "" {
|
||||
return addr
|
||||
}
|
||||
return "localhost:8082"
|
||||
}
|
||||
|
||||
type ConfigureBrowserExtensionResponse struct {
|
||||
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
|
||||
ConnectionToken string `json:"connection_token"`
|
||||
}
|
||||
|
||||
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
|
||||
extensionID := uuid.NewString()
|
||||
if extensionIDFromEnv := os.Getenv("TEST_EXTENSION_ID"); extensionIDFromEnv != "" {
|
||||
extensionID = extensionIDFromEnv
|
||||
}
|
||||
req := struct {
|
||||
ExtensionID string `json:"extension_id"`
|
||||
}{
|
||||
ExtensionID: extensionID,
|
||||
}
|
||||
var resp ConfigureBrowserExtensionResponse
|
||||
|
||||
if err := request("POST", "/browser_extension/configure", "", req, &resp); err != nil {
|
||||
return resp, fmt.Errorf("failed to configure browser: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// confirmMobile confirms pairing and returns mobile proxy token.
|
||||
func confirmMobile(connectionToken, fcm string) (string, error) {
|
||||
deviceID := uuid.NewString()
|
||||
if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" {
|
||||
deviceID = deviceIDFromEnv
|
||||
}
|
||||
|
||||
req := struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
FCMToken string `json:"fcm_token"`
|
||||
}{
|
||||
DeviceID: deviceID,
|
||||
FCMToken: fcm,
|
||||
}
|
||||
resp := struct {
|
||||
ProxyToken string `json:"proxy_token"`
|
||||
}{}
|
||||
|
||||
if err := request("POST", "/mobile/confirm", connectionToken, req, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to configure browser: %w", err)
|
||||
}
|
||||
|
||||
return resp.ProxyToken, nil
|
||||
}
|
||||
|
||||
// confirmSyncMobile confirms pairing and returns mobile proxy token.
|
||||
func confirmSyncMobile(connectionToken string) (string, error) {
|
||||
var result string
|
||||
|
||||
err := retry.Do(func() error {
|
||||
var err error
|
||||
result, err = confirmSyncMobileRequest(connectionToken)
|
||||
return err
|
||||
})
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func confirmSyncMobileRequest(connectionToken string) (string, error) {
|
||||
var resp struct {
|
||||
ProxyToken string `json:"proxy_token"`
|
||||
}
|
||||
|
||||
if err := request("POST", "/mobile/sync/confirm", connectionToken, nil, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to confirm mobile: %w", err)
|
||||
}
|
||||
|
||||
return resp.ProxyToken, nil
|
||||
}
|
||||
|
||||
func getMobileToken(fcm string) (string, error) {
|
||||
var resp struct {
|
||||
MobileSyncConfirmToken string `json:"mobile_sync_confirm_token"`
|
||||
}
|
||||
if err := request("GET", fmt.Sprintf("/mobile/sync/%s/token", fcm), "", nil, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to get mobile token")
|
||||
}
|
||||
|
||||
return resp.MobileSyncConfirmToken, nil
|
||||
}
|
||||
|
||||
func request(method, path, auth string, req, resp interface{}) error {
|
||||
url := getApiURL() + path
|
||||
var body io.Reader
|
||||
if req != nil {
|
||||
bb, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to request marshal: %w", err)
|
||||
}
|
||||
body = bytes.NewBuffer(bb)
|
||||
}
|
||||
httpReq, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
if auth != "" {
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", auth))
|
||||
}
|
||||
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
|
||||
return fmt.Errorf("failed perform the request: %w", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
bb, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read body from response: %w", err)
|
||||
}
|
||||
|
||||
if httpResp.StatusCode >= 300 {
|
||||
return fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
|
||||
}
|
||||
if err := json.Unmarshal(bb, &resp); err != nil {
|
||||
return fmt.Errorf("failed to decode the response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -26,7 +26,7 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kmsClient := kms.New(sess)
|
||||
srv, err := sign.NewService("alias/pass_service", kmsClient)
|
||||
srv, err := sign.NewService("alias/pass_service_signing_key", kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -61,7 +61,7 @@ func TestSignAndVerify(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kmsClient := kms.New(sess)
|
||||
srv, err := sign.NewService("alias/pass_service", kmsClient)
|
||||
srv, err := sign.NewService("alias/pass_service_signing_key", kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
60
tests/pass/pair_test.go
Normal file
60
tests/pass/pair_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestPairHappyFlow(t *testing.T) {
|
||||
resp, err := configureBrowserExtension()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to configure browser extension: %v", err)
|
||||
}
|
||||
|
||||
browserExtensionDone := make(chan struct{})
|
||||
mobileDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(browserExtensionDone)
|
||||
|
||||
extProxyToken, _, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
|
||||
if err != nil {
|
||||
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/browser_extension/proxy_to_mobile",
|
||||
extProxyToken,
|
||||
"sent from browser extension",
|
||||
"sent from mobile")
|
||||
if err != nil {
|
||||
t.Errorf("Browser Extension: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
}()
|
||||
go func() {
|
||||
defer close(mobileDone)
|
||||
|
||||
mobileProxyToken, err := confirmMobile(resp.ConnectionToken, uuid.NewString())
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: confirm failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/mobile/proxy_to_browser_extension",
|
||||
mobileProxyToken,
|
||||
"sent from mobile",
|
||||
"sent from browser extension",
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-browserExtensionDone
|
||||
<-mobileDone
|
||||
}
|
@ -1,262 +0,0 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type ConfigureBrowserExtensionResponse struct {
|
||||
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
|
||||
ConnectionToken string `json:"connection_token"`
|
||||
}
|
||||
|
||||
var (
|
||||
httpClient = http.DefaultClient
|
||||
wsDialer = websocket.DefaultDialer
|
||||
)
|
||||
|
||||
func getApiURL() string {
|
||||
apiURL := os.Getenv("API_URL")
|
||||
if apiURL != "" {
|
||||
return apiURL
|
||||
}
|
||||
return "http://" + getPassAddr()
|
||||
}
|
||||
|
||||
func getWsURL() string {
|
||||
wsURL := os.Getenv("WS_URL")
|
||||
if wsURL != "" {
|
||||
return wsURL
|
||||
}
|
||||
return "ws://" + getPassAddr()
|
||||
}
|
||||
|
||||
func getPassAddr() string {
|
||||
addr := os.Getenv("PASS_ADDR")
|
||||
if addr != "" {
|
||||
return addr
|
||||
}
|
||||
return "localhost:8082"
|
||||
}
|
||||
|
||||
func TestPassHappyFlow(t *testing.T) {
|
||||
resp, err := configureBrowserExtension()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to configure browser extension: %v", err)
|
||||
}
|
||||
|
||||
browserExtensionDone := make(chan struct{})
|
||||
mobileDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(browserExtensionDone)
|
||||
|
||||
extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
|
||||
if err != nil {
|
||||
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/browser_extension/proxy_to_mobile",
|
||||
extProxyToken,
|
||||
"sent from browser extension",
|
||||
"sent from mobile")
|
||||
if err != nil {
|
||||
t.Errorf("Browser Extension: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
}()
|
||||
go func() {
|
||||
defer close(mobileDone)
|
||||
|
||||
mobileProxyToken, err := confirmMobile(resp.ConnectionToken)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: confirm failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/mobile/proxy_to_browser_extension",
|
||||
mobileProxyToken,
|
||||
"sent from mobile",
|
||||
"sent from browser extension",
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-browserExtensionDone
|
||||
<-mobileDone
|
||||
}
|
||||
|
||||
func browserExtensionWaitForConfirm(token string) (string, error) {
|
||||
url := getWsURL() + "/browser_extension/wait_for_connection"
|
||||
|
||||
var resp struct {
|
||||
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||
Status string `json:"status"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading from connection: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(message, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
const expectedStatus = "ok"
|
||||
if resp.Status != expectedStatus {
|
||||
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
|
||||
}
|
||||
return resp.BrowserExtensionProxyToken, nil
|
||||
}
|
||||
|
||||
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
|
||||
url := getApiURL() + "/browser_extension/configure"
|
||||
|
||||
extensionID := uuid.NewString()
|
||||
if extensionIDFromEnv := os.Getenv("TEST_EXTENSION_ID"); extensionIDFromEnv != "" {
|
||||
extensionID = extensionIDFromEnv
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, extensionID))
|
||||
if err != nil {
|
||||
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
httpResp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed perform the request: %w", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
bb, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to read body from response: %w", err)
|
||||
}
|
||||
|
||||
if httpResp.StatusCode >= 300 {
|
||||
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
|
||||
}
|
||||
|
||||
var resp ConfigureBrowserExtensionResponse
|
||||
if err := json.Unmarshal(bb, &resp); err != nil {
|
||||
return resp, fmt.Errorf("failed to decode the response: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// confirmMobile confirms pairing and returns mobile proxy token.
|
||||
func confirmMobile(connectionToken string) (string, error) {
|
||||
url := getApiURL() + "/mobile/confirm"
|
||||
|
||||
deviceID := uuid.NewString()
|
||||
if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" {
|
||||
deviceID = deviceIDFromEnv
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, deviceID))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to prepare the reqest: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
|
||||
|
||||
httpResp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to perform the reqest: %w", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
bb, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read body from response: %w", err)
|
||||
}
|
||||
|
||||
if httpResp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
ProxyToken string `json:"proxy_token"`
|
||||
}
|
||||
if err := json.Unmarshal(bb, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode the response: %w", err)
|
||||
}
|
||||
|
||||
return resp.ProxyToken, nil
|
||||
}
|
||||
|
||||
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
|
||||
// read exactly one message (and then check it is `expectedReadMsg`).
|
||||
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
doneReading := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(doneReading)
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
doneReading <- fmt.Errorf("faile to read message: %w", err)
|
||||
return
|
||||
}
|
||||
if string(message) != expectedReadMsg {
|
||||
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
err, _ = <-doneReading
|
||||
if err != nil {
|
||||
return fmt.Errorf("error when reading: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dialWS(url, auth string) (*websocket.Conn, error) {
|
||||
authEncodedAsProtocol := fmt.Sprintf("base64url.bearer.authorization.2pass.io.%s", base64.RawURLEncoding.EncodeToString([]byte(auth)))
|
||||
|
||||
conn, _, err := wsDialer.Dial(url, http.Header{
|
||||
"Sec-WebSocket-Protocol": []string{
|
||||
"2pass.io",
|
||||
authEncodedAsProtocol,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func bytesPrintf(format string, ii ...interface{}) io.Reader {
|
||||
s := fmt.Sprintf(format, ii...)
|
||||
return bytes.NewBufferString(s)
|
||||
}
|
83
tests/pass/sync_test.go
Normal file
83
tests/pass/sync_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSyncHappyFlow(t *testing.T) {
|
||||
resp, err := configureBrowserExtension()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to configure browser extension: %v", err)
|
||||
}
|
||||
|
||||
browserExtensionDone := make(chan struct{})
|
||||
mobileParingDone := make(chan struct{})
|
||||
|
||||
fcm := uuid.NewString()
|
||||
|
||||
go func() {
|
||||
defer close(browserExtensionDone)
|
||||
_, syncToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
|
||||
if err != nil {
|
||||
t.Errorf("Error when Browser Extension waited for pairing confirm: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
proxyToken, err := browserExtensionWaitForSyncConfirm(syncToken)
|
||||
if err != nil {
|
||||
t.Errorf("Error when Browser Extension waited for sync confirm: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/browser_extension/sync/proxy",
|
||||
proxyToken,
|
||||
"sent from browser extension",
|
||||
"sent from mobile")
|
||||
if err != nil {
|
||||
t.Errorf("Browser Extension: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(mobileParingDone)
|
||||
|
||||
_, err := confirmMobile(resp.ConnectionToken, fcm)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: confirm failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
confirmToken, err := getMobileToken(fcm)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to fetch mobile token: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
proxyToken, err := confirmSyncMobile(confirmToken)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to confirm mobile: %v", err)
|
||||
return
|
||||
}
|
||||
if proxyToken == "" {
|
||||
t.Errorf("Mobile: proxy token is empty")
|
||||
return
|
||||
}
|
||||
|
||||
err = proxyWebSocket(
|
||||
getWsURL()+"/mobile/sync/proxy",
|
||||
proxyToken,
|
||||
"sent from mobile",
|
||||
"sent from browser extension",
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: proxy failed: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
<-browserExtensionDone
|
||||
<-mobileParingDone
|
||||
}
|
133
tests/pass/ws.go
Normal file
133
tests/pass/ws.go
Normal file
@ -0,0 +1,133 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
wsDialer = websocket.DefaultDialer
|
||||
)
|
||||
|
||||
func getWsURL() string {
|
||||
wsURL := os.Getenv("WS_URL")
|
||||
if wsURL != "" {
|
||||
return wsURL
|
||||
}
|
||||
return "ws://" + getPassAddr()
|
||||
}
|
||||
|
||||
func browserExtensionWaitForSyncConfirm(token string) (string, error) {
|
||||
url := getWsURL() + "/browser_extension/sync/request"
|
||||
|
||||
var resp struct {
|
||||
BrowserExtensionSyncToken string `json:"browser_extension_proxy_token"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(15 * time.Second))
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading from connection: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(message, &resp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
const expectedStatus = "ok"
|
||||
if resp.Status != expectedStatus {
|
||||
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
|
||||
}
|
||||
return resp.BrowserExtensionSyncToken, nil
|
||||
}
|
||||
|
||||
func browserExtensionWaitForConfirm(token string) (string, string, error) {
|
||||
url := getWsURL() + "/browser_extension/wait_for_connection"
|
||||
|
||||
var resp struct {
|
||||
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||
BrowserExtensionSyncToken string `json:"browser_extension_sync_token"`
|
||||
Status string `json:"status"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("error reading from connection: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(message, &resp); err != nil {
|
||||
return "", "", fmt.Errorf("failed to decode message: %w", err)
|
||||
}
|
||||
const expectedStatus = "ok"
|
||||
if resp.Status != expectedStatus {
|
||||
return "", "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
|
||||
}
|
||||
return resp.BrowserExtensionProxyToken, resp.BrowserExtensionSyncToken, nil
|
||||
}
|
||||
|
||||
func dialWS(url, auth string) (*websocket.Conn, error) {
|
||||
authEncodedAsProtocol := fmt.Sprintf("base64url.bearer.authorization.2pass.io.%s", base64.RawURLEncoding.EncodeToString([]byte(auth)))
|
||||
|
||||
conn, _, err := wsDialer.Dial(url, http.Header{
|
||||
"Sec-WebSocket-Protocol": []string{
|
||||
"2pass.io",
|
||||
authEncodedAsProtocol,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
|
||||
// read exactly one message (and then check it is `expectedReadMsg`).
|
||||
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
doneReading := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(doneReading)
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
doneReading <- fmt.Errorf("faile to read message: %w", err)
|
||||
return
|
||||
}
|
||||
if string(message) != expectedReadMsg {
|
||||
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
err, _ = <-doneReading
|
||||
if err != nil {
|
||||
return fmt.Errorf("error when reading: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user