feat: add sync endpoint to pass

This commit is contained in:
Krzysztof Dryś 2024-02-29 22:55:32 +01:00
parent cab30f73e6
commit e6b1c17809
21 changed files with 1219 additions and 591 deletions

View File

@ -1,8 +1,9 @@
package config
type PassConfig struct {
Addr string `envconfig:"PASS_ADDR" default:":8082"`
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
AWSRegion string `envconfig:"AWS_REGION" default:"us-east-2"`
Addr string `envconfig:"PASS_ADDR" default:":8082"`
KMSKeyID string `envconfig:"KMS_KEY_ID" default:"alias/pass_service_signing_key"`
AWSEndpoint string `envconfig:"AWS_ENDPOINT" default:""`
AWSRegion string `envconfig:"AWS_REGION" default:"us-east-2"`
FakeMobilePush bool `envconfig:"FAKE_MOBILE_PUSH" default:"false"`
}

View File

@ -0,0 +1,149 @@
package connection
import (
"bytes"
"time"
"github.com/gorilla/websocket"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/recovery"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 4 * 1048
)
var (
newline = []byte{'\n'}
space = []byte{' '}
acceptedCloseStatus = []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
}
)
// proxy is a responsible for reading from read chan and sending it over wsConn
// and reading fom wsChan and sending it over send chan
type proxy struct {
send chan []byte
read chan []byte
conn *websocket.Conn
}
func StartProxy(wsConn *websocket.Conn, send, read chan []byte) {
proxy := &proxy{
send: send,
read: read,
conn: wsConn,
}
go recovery.DoNotPanic(func() {
proxy.writePump()
})
go recovery.DoNotPanic(func() {
proxy.readPump()
})
go recovery.DoNotPanic(func() {
disconnectAfter := 3 * time.Minute
timeout := time.After(disconnectAfter)
<-timeout
logging.Info("Connection closed after", disconnectAfter)
proxy.conn.Close()
})
}
// readPump pumps messages from the websocket connection to send.
//
// The application runs readPump in a per-connection goroutine. The application
// ensures that there is at most one reader on a connection by executing all
// reads from this goroutine.
func (p *proxy) readPump() {
defer func() {
p.conn.Close()
close(p.send)
}()
p.conn.SetReadLimit(maxMessageSize)
p.conn.SetReadDeadline(time.Now().Add(pongWait))
p.conn.SetPongHandler(func(string) error {
p.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, message, err := p.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) {
logging.WithFields(logging.Fields{
"reason": err.Error(),
}).Error("Websocket connection closed unexpected")
} else {
logging.WithFields(logging.Fields{
"reason": err.Error(),
}).Info("Connection closed")
}
break
}
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
p.send <- message
}
}
// writePump pumps messages from the read chan to the websocket connection.
//
// A goroutine running writePump is started for each connection. The
// application ensures that there is at most one writer to a connection by
// executing all writes from this goroutine.
func (p *proxy) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
p.conn.Close()
}()
for {
select {
case message, ok := <-p.read:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// The hub closed the channel.
p.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := p.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}

View File

@ -0,0 +1,103 @@
package connection
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/twofas/2fas-server/internal/common/logging"
)
// Proxy between Browser Extension and Mobile.
type Proxy struct {
proxyPool *proxyPool
idLabel string
}
func NewProxy(idLabel string) *Proxy {
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
go func() {
ticker := time.NewTicker(30 * time.Second)
for {
<-ticker.C
proxyPool.deleteExpiresPairs()
}
}()
return &Proxy{
proxyPool: proxyPool,
idLabel: idLabel,
}
}
type proxyPool struct {
mu sync.Mutex
proxies map[string]*proxyPair
}
// registerMobileConn register proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair {
pp.mu.Lock()
defer pp.mu.Unlock()
v, ok := pp.proxies[id]
if !ok {
v = initProxyPair()
}
pp.proxies[id] = v
return v
}
func (pp *proxyPool) deleteExpiresPairs() {
pp.mu.Lock()
defer pp.mu.Unlock()
for key, pair := range pp.proxies {
if time.Now().After(pair.expiresAt) {
delete(pp.proxies, key)
}
}
}
type proxyPair struct {
toMobileDataCh chan []byte
toExtensionDataCh chan []byte
expiresAt time.Time
}
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
func initProxyPair() *proxyPair {
const proxyTimeout = 3 * time.Minute
return &proxyPair{
toMobileDataCh: make(chan []byte),
toExtensionDataCh: make(chan []byte),
expiresAt: time.Now().Add(proxyTimeout),
}
}
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, id string) error {
log := logging.WithField(p.idLabel, id)
conn, err := Upgrade(w, r)
if err != nil {
return fmt.Errorf("failed to upgrade connection: %w", err)
}
log.Infof("Starting ServeExtensionProxyToMobileWS")
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
StartProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh)
return nil
}
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, id string) error {
conn, err := Upgrade(w, r)
if err != nil {
return fmt.Errorf("failed to upgrade connection: %w", err)
}
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
StartProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh)
return nil
}

View File

@ -1,4 +1,4 @@
package pairing
package connection
import (
"encoding/base64"
@ -18,16 +18,16 @@ const bearerProtocolPrefix = "base64url.bearer.authorization.2pass.io."
var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol")
// tokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
// TokenFromWSProtocol returns authorization token from 'Sec-WebSocket-Protocol' request header.
// It is used because websocket API in browser does not allow to pass headers.
// https://github.com/kubernetes/kubernetes/commit/714f97d7baf4975ad3aa47735a868a81a984d1f0
//
// Client provides a bearer token as a subprotocol in the format
// "base64url.bearer.authorization.2pass.io.<base64url-without-padding(bearer-token)>".
// This function also modified request header by removing authorization header.
// Server according to spec must return at least 1 protocol to client, so this fn checks
// Server according to spec must return at least 1 protocol to proxy, so this fn checks
// if at least 1 protocol is sent, beside authorization header.
func tokenFromWSProtocol(req *http.Request) (string, error) {
func TokenFromWSProtocol(req *http.Request) (string, error) {
token := ""
sawTokenProtocol := false
filteredProtocols := []string{}
@ -62,7 +62,7 @@ func tokenFromWSProtocol(req *http.Request) (string, error) {
}
// Must pass at least one other subprotocol so that we can remove the one containing the bearer token,
// and there is at least one to echo back to the client
// and there is at least one to echo back to the proxy
if len(filteredProtocols) == 0 {
return "", errors.New("missing additional subprotocol")
}
@ -91,10 +91,14 @@ var upgrader2pass = websocket.Upgrader{
Subprotocols: []string{supportedProtocol2pass},
}
func wsUpgraderForProtocol(req *http.Request) (websocket.Upgrader, error) {
func Upgrade(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) {
protocols := strings.Split(req.Header.Get(protocolHeader), ",")
if slices.Contains(protocols, supportedProtocol2pass) {
return upgrader2pass, nil
if !slices.Contains(protocols, supportedProtocol2pass) {
return nil, fmt.Errorf("upgrader not available for protocols: %v", protocols)
}
return websocket.Upgrader{}, fmt.Errorf("upgrader not available for protocols: %v", protocols)
conn, err := upgrader2pass.Upgrade(w, req, nil)
if err != nil {
return nil, fmt.Errorf("failed to upgrade connection: %w", err)
}
return conn, nil
}

View File

@ -1,4 +1,4 @@
package pairing
package connection
import (
"net/http"
@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)
func Test_tokenFromWSProtocol(t *testing.T) {
func Test_TokenFromWSProtocol(t *testing.T) {
tests := []struct {
name string
protocolHeader string
@ -56,7 +56,7 @@ func Test_tokenFromWSProtocol(t *testing.T) {
if protocolHeader != "" {
req.Header.Set(protocolHeader, tt.protocolHeader)
}
got, err := tokenFromWSProtocol(req)
got, err := TokenFromWSProtocol(req)
tt.assertFn(t, got, err)
})
}

View File

@ -9,6 +9,7 @@ import (
"github.com/google/uuid"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection"
)
func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
@ -35,7 +36,7 @@ func ExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := tokenFromWSProtocol(gCtx.Request)
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
@ -45,17 +46,21 @@ func ExtensionWaitForConnWSHandler(pairingApp *Pairing) gin.HandlerFunc {
extensionID, err := pairingApp.VerifyPairingToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify pairing token: %v", err)
gCtx.Status(http.StatusInternalServerError)
gCtx.Status(http.StatusUnauthorized)
return
}
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
if err := pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID); err != nil {
logging.Errorf("Failed serve ws: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
}
}
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *connection.Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := tokenFromWSProtocol(gCtx.Request)
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
@ -77,7 +82,11 @@ func ExtensionProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFu
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
return
}
proxyApp.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, extensionID, pairingInfo.Device.DeviceID)
if err := proxyApp.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID); err != nil {
logging.Errorf("Failed to serve ws: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
}
}
@ -116,9 +125,9 @@ func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
}
}
func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
func MobileProxyWSHandler(pairingApp *Pairing, proxy *connection.Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := tokenFromWSProtocol(gCtx.Request)
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
@ -141,7 +150,11 @@ func MobileProxyWSHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
return
}
proxyApp.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID)
if err := proxy.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID); err != nil {
log.Errorf("Failed to serve ws: %w", err)
gCtx.Status(http.StatusInternalServerError)
return
}
}
}

View File

@ -8,7 +8,9 @@ import (
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection"
"github.com/twofas/2fas-server/internal/pass/sign"
)
@ -77,31 +79,25 @@ type ExtensionWaitForConnectionInput struct {
type WaitForConnectionResponse struct {
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
BrowserExtensionSyncToken string `json:"browser_extension_sync_token"`
Status string `json:"status"`
DeviceID string `json:"device_id"`
}
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) error {
log := logging.WithField("extension_id", extID)
upgrader, err := wsUpgraderForProtocol(r)
conn, err := connection.Upgrade(w, r)
if err != nil {
log.Error(err)
return
return fmt.Errorf("failed to upgrade connection: %w", err)
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
return
}
defer conn.Close()
log.Info("Starting pairing WS")
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
if pairing, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
if err := p.sendTokenAndCloseConn(extID, pairing, conn); err != nil {
log.Errorf("Failed to send token: %v", err)
}
return
return nil
}
const (
@ -116,30 +112,30 @@ func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID s
select {
case <-maxWaitC:
log.Info("Closing paring ws after timeout")
return
return nil
case <-connectedCheckTicker.C:
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
if pairing, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
if err := p.sendTokenAndCloseConn(extID, pairing, conn); err != nil {
log.Errorf("Failed to send token: %v", err)
return
return nil
}
log.WithField("device_id", deviceID).Infof("Paring ws finished")
return
log.WithField("device_id", pairing.Device.DeviceID).Infof("Paring ws finished")
return nil
}
}
}
}
func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (string, bool) {
func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (PairingInfo, bool) {
pairingInfo, err := p.store.GetPairingInfo(ctx, extID)
if err != nil {
log.Warn("Failed to get pairing info")
return "", false
return PairingInfo{}, false
}
return pairingInfo.Device.DeviceID, pairingInfo.IsPaired()
return pairingInfo, pairingInfo.IsPaired()
}
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
func (p *Pairing) sendTokenAndCloseConn(extID string, pairingInfo PairingInfo, conn *websocket.Conn) error {
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: extID,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
@ -148,11 +144,20 @@ func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.
if err != nil {
return fmt.Errorf("Failed to generate ext proxy token: %v", err)
}
var syncToken string
if pairingInfo.Device.FCMToken != "" {
syncToken, err = p.signSvc.SignAndEncode(sign.Message{
ConnectionID: pairingInfo.Device.FCMToken,
ExpiresAt: time.Now().AddDate(1, 0, 0),
ConnectionType: sign.ConnectionTypeBrowserExtensionSyncRequest,
})
}
if err := conn.WriteJSON(WaitForConnectionResponse{
BrowserExtensionProxyToken: extProxyToken,
BrowserExtensionSyncToken: syncToken,
Status: "ok",
DeviceID: deviceID,
DeviceID: pairingInfo.Device.DeviceID,
}); err != nil {
return fmt.Errorf("failed to write to extension: %v", err)
}
@ -191,5 +196,6 @@ func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest,
}); err != nil {
return ConfirmPairingResponse{}, err
}
return ConfirmPairingResponse{ProxyToken: mobileProxyToken}, nil
}

View File

@ -1,269 +0,0 @@
package pairing
import (
"bytes"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/recovery"
)
type Proxy struct {
proxyPool *proxyPool
}
func NewProxy() *Proxy {
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
go func() {
ticker := time.NewTicker(30 * time.Second)
for {
<-ticker.C
proxyPool.deleteExpiresPairs()
}
}()
return &Proxy{
proxyPool: proxyPool,
}
}
type proxyPool struct {
mu sync.Mutex
proxies map[string]*proxyPair
}
// registerMobileConn register proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(deviceID string) *proxyPair {
// TODO: handle delete.
// TODO: right now two connections to the same WS results in race for messages/ decide if we want multiple conn or not.
pp.mu.Lock()
defer pp.mu.Unlock()
v, ok := pp.proxies[deviceID]
if !ok {
v = initProxyPair()
}
pp.proxies[deviceID] = v
return v
}
func (pp *proxyPool) deleteExpiresPairs() {
pp.mu.Lock()
defer pp.mu.Unlock()
for key, pair := range pp.proxies {
if time.Now().After(pair.expiresAt) {
delete(pp.proxies, key)
}
}
}
type proxyPair struct {
toMobileDataCh chan []byte
toExtensionDataCh chan []byte
expiresAt time.Time
}
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
func initProxyPair() *proxyPair {
const proxyTimeout = 3 * time.Minute
return &proxyPair{
toMobileDataCh: make(chan []byte),
toExtensionDataCh: make(chan []byte),
expiresAt: time.Now().Add(proxyTimeout),
}
}
var (
newline = []byte{'\n'}
space = []byte{' '}
acceptedCloseStatus = []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
websocket.CloseAbnormalClosure,
}
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer.
maxMessageSize = 4 * 1048
)
// client is a responsible for reading from read chan and sending it over wsConn
// and reading fom wsChan and sending it over send chan
type client struct {
send chan []byte
read chan []byte
conn *websocket.Conn
}
func newClient(wsConn *websocket.Conn, send, read chan []byte) *client {
return &client{
send: send,
read: read,
conn: wsConn,
}
}
// readPump pumps messages from the websocket connection to send.
//
// The application runs readPump in a per-connection goroutine. The application
// ensures that there is at most one reader on a connection by executing all
// reads from this goroutine.
func (c *client) readPump() {
defer func() {
c.conn.Close()
close(c.send)
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) {
logging.WithFields(logging.Fields{
"reason": err.Error(),
}).Error("Websocket connection closed unexpected")
} else {
logging.WithFields(logging.Fields{
"reason": err.Error(),
}).Info("Connection closed")
}
break
}
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
c.send <- message
}
}
// writePump pumps messages from the read chan to the websocket connection.
//
// A goroutine running writePump is started for each connection. The
// application ensures that there is at most one writer to a connection by
// executing all writes from this goroutine.
func (c *client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.read:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// The hub closed the channel.
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
w.Write(message)
if err := w.Close(); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
upgrader, err := wsUpgraderForProtocol(r)
if err != nil {
log.Error(err)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
return
}
log.Infof("Starting ServeExtensionProxyToMobileWS")
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
client := newClient(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh)
go recovery.DoNotPanic(func() {
client.writePump()
})
go recovery.DoNotPanic(func() {
client.readPump()
})
go recovery.DoNotPanic(func() {
disconnectAfter := 3 * time.Minute
timeout := time.After(disconnectAfter)
<-timeout
logging.Info("Connection closed after", disconnectAfter)
client.conn.Close()
})
}
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
upgrader, err := wsUpgraderForProtocol(r)
if err != nil {
logging.Error(err)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
return
}
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", deviceID)
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
client := newClient(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh)
go recovery.DoNotPanic(func() {
client.writePump()
})
go recovery.DoNotPanic(func() {
client.readPump()
})
go recovery.DoNotPanic(func() {
disconnectAfter := 3 * time.Minute
timeout := time.After(disconnectAfter)
<-timeout
logging.Info("Connection closed after", disconnectAfter)
client.conn.Close()
})
}

View File

@ -7,12 +7,15 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/pass/sign"
"github.com/twofas/2fas-server/config"
httphelpers "github.com/twofas/2fas-server/internal/common/http"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/recovery"
"github.com/twofas/2fas-server/internal/pass/connection"
"github.com/twofas/2fas-server/internal/pass/pairing"
"github.com/twofas/2fas-server/internal/pass/sign"
"github.com/twofas/2fas-server/internal/pass/sync"
)
type Server struct {
@ -49,7 +52,10 @@ func NewServer(cfg config.PassConfig) *Server {
}
pairingApp := pairing.NewPairingApp(signSvc)
proxyApp := pairing.NewProxy()
proxyPairingApp := connection.NewProxy("fcm_token")
syncApp := sync.NewPairingApp(signSvc, cfg.FakeMobilePush)
proxySyncApp := connection.NewProxy("device_id")
router := gin.New()
router.Use(recovery.RecoveryMiddleware())
@ -64,9 +70,24 @@ func NewServer(cfg config.PassConfig) *Server {
router.POST("/browser_extension/configure", pairing.ExtensionConfigureHandler(pairingApp))
router.GET("/browser_extension/wait_for_connection", pairing.ExtensionWaitForConnWSHandler(pairingApp))
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyApp))
router.GET("/browser_extension/proxy_to_mobile", pairing.ExtensionProxyWSHandler(pairingApp, proxyPairingApp))
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyApp))
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyWSHandler(pairingApp, proxyPairingApp))
router.GET("/browser_extension/pairing/wait", pairing.ExtensionWaitForConnWSHandler(pairingApp))
router.GET("/browser_extension/pairing/proxy", pairing.ExtensionProxyWSHandler(pairingApp, proxyPairingApp))
router.POST("/mobile/pairing/confirm", pairing.MobileConfirmHandler(pairingApp))
router.GET("/mobile/pairing/proxy", pairing.MobileProxyWSHandler(pairingApp, proxyPairingApp))
router.GET("/browser_extension/sync/request", sync.ExtensionRequestSync(syncApp))
router.GET("/browser_extension/sync/proxy", sync.ExtensionProxyWSHandler(syncApp, proxySyncApp))
router.POST("/mobile/sync/confirm", sync.MobileConfirmHandler(syncApp))
router.GET("/mobile/sync/proxy", sync.MobileProxyWSHandler(syncApp, proxySyncApp))
if cfg.FakeMobilePush {
logging.Info("Enabled '/mobile/sync/:fcm/token' endpoint. This should happen in test env only!")
router.GET("/mobile/sync/:fcm/token", sync.MobileGenerateSyncToken(syncApp))
}
return &Server{
router: router,

View File

@ -56,8 +56,12 @@ func NewService(keyID string, client *kms.KMS) (*Service, error) {
type ConnectionType string
const (
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
ConnectionTypeBrowserExtensionSyncRequest ConnectionType = "be/sync/request"
ConnectionTypeBrowserExtensionSync ConnectionType = "be/sync/proxy"
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
ConnectionTypeMobileConfirm ConnectionType = "mobile/confirm"
ConnectionTypeMobileSyncConfirm ConnectionType = "mobile/sync/confirm"
ConnectionTypeMobileSyncProxy ConnectionType = "mobile/sync/proxy"
)

View File

@ -0,0 +1,44 @@
package sync
import (
"context"
"fmt"
"github.com/twofas/2fas-server/internal/pass/sign"
)
// VerifyExtRequestSyncToken verifies sync request token and returns fcm_token.
func (p *Syncing) VerifyExtRequestSyncToken(ctx context.Context, proxyToken string) (string, error) {
fcmToken, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionSyncRequest)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
return fcmToken, nil
}
// VerifyExtSyncToken verifies sync token and returns fcm_token.
func (p *Syncing) VerifyExtSyncToken(ctx context.Context, proxyToken string) (string, error) {
fcmToken, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeBrowserExtensionSync)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
return fcmToken, nil
}
// VerifyMobileSyncConfirmToken verifies mobile token and returns connection id.
func (p *Syncing) VerifyMobileSyncConfirmToken(ctx context.Context, proxyToken string) (string, error) {
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileSyncConfirm)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
return extensionID, nil
}
// VerifyMobileSyncProxyToken verifies mobile token and returns connection id.
func (p *Syncing) VerifyMobileSyncProxyToken(ctx context.Context, proxyToken string) (string, error) {
extensionID, err := p.signSvc.CanI(proxyToken, sign.ConnectionTypeMobileSyncProxy)
if err != nil {
return "", fmt.Errorf("failed to check token signature: %w", err)
}
return extensionID, nil
}

View File

@ -0,0 +1,141 @@
package sync
import (
"errors"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection"
)
func ExtensionRequestSync(syncingApp *Syncing) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
return
}
fcmToken, err := syncingApp.VerifyExtRequestSyncToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify proxy token: %v", err)
gCtx.String(http.StatusUnauthorized, "Invalid auth token")
return
}
if err := syncingApp.ServeSyncingRequestWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
logging.Errorf("Failed to verify proxy token: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
}
}
func ExtensionProxyWSHandler(syncingApp *Syncing, proxy *connection.Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
return
}
fcmToken, err := syncingApp.VerifyExtSyncToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify proxy token: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
ok := syncingApp.isSyncConfirmed(gCtx, fcmToken)
if !ok {
gCtx.String(http.StatusForbidden, "Syncing is not yet done")
return
}
if err := proxy.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
logging.Errorf("Failed to serve ws: %v", err)
gCtx.Status(http.StatusInternalServerError)
}
}
}
func MobileConfirmHandler(syncApp *Syncing) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := tokenFromRequest(gCtx)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
return
}
fcmToken, err := syncApp.VerifyMobileSyncConfirmToken(gCtx, token)
if err != nil {
logging.Errorf("Failed to verify connection token: %v", err)
gCtx.Status(http.StatusUnauthorized)
return
}
resp, err := syncApp.confirmSync(gCtx, fcmToken)
if err != nil {
if errors.Is(err, noSyncRequestErr) {
gCtx.String(http.StatusBadRequest, "no sync request was created for this token")
}
logging.Errorf("Failed to ConfirmPairing: %v", err)
gCtx.Status(http.StatusInternalServerError)
return
}
gCtx.JSON(http.StatusOK, resp)
}
}
func MobileProxyWSHandler(syncingApp *Syncing, proxy *connection.Proxy) gin.HandlerFunc {
return func(gCtx *gin.Context) {
token, err := connection.TokenFromWSProtocol(gCtx.Request)
if err != nil {
logging.Errorf("Failed to get token from request: %v", err)
gCtx.Status(http.StatusForbidden)
return
}
fcmToken, err := syncingApp.VerifyMobileSyncConfirmToken(gCtx, token)
if err != nil {
logging.Errorf("Invalid connection token: %v", err)
gCtx.Status(http.StatusUnauthorized)
return
}
ok := syncingApp.isSyncConfirmed(gCtx, fcmToken)
if !ok {
gCtx.String(http.StatusForbidden, "Syncing is not yet done")
return
}
if err := proxy.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, fcmToken); err != nil {
logging.Errorf("Failed to serve ws: %v", err)
gCtx.Status(http.StatusInternalServerError)
}
}
}
func MobileGenerateSyncToken(syncingApp *Syncing) gin.HandlerFunc {
return func(gCtx *gin.Context) {
fcm := gCtx.Param("fcm")
if fcm == "" {
gCtx.Status(http.StatusBadRequest)
return
}
if err := syncingApp.sendMobileToken(fcm, gCtx.Writer); err != nil {
logging.Errorf("Failed to send mobile token: %v", err)
gCtx.Status(http.StatusInternalServerError)
}
}
}
func tokenFromRequest(gCtx *gin.Context) (string, error) {
tokenHeader := gCtx.GetHeader("Authorization")
if tokenHeader == "" {
return "", errors.New("missing Authorization header")
}
splitToken := strings.Split(tokenHeader, "Bearer ")
if len(splitToken) != 2 {
gCtx.Status(http.StatusForbidden)
return "", errors.New("missing 'Bearer: value'")
}
return splitToken[1], nil
}

View File

@ -0,0 +1,61 @@
package sync
import (
"sync"
"time"
)
// MemoryStore keeps in memory pairing between extension and mobile.
//
// TODO: check ttlcache pkg, right now entries are not invalidated.
type MemoryStore struct {
mu sync.Mutex
extensionsMap map[string]Item
}
type Item struct {
FCMToken string
Expires time.Time
Confirmed bool
}
func NewMemoryStore() *MemoryStore {
return &MemoryStore{
extensionsMap: make(map[string]Item),
}
}
func (s *MemoryStore) RequestSync(fcmToken string) {
s.setItem(fcmToken, Item{FCMToken: fcmToken})
}
func (s *MemoryStore) ConfirmSync(fcmToken string) bool {
v, ok := s.getItem(fcmToken)
if !ok {
return false
}
v.Confirmed = true
s.setItem(fcmToken, v)
return true
}
func (s *MemoryStore) IsSyncCofirmed(fcmToken string) bool {
v, ok := s.getItem(fcmToken)
if !ok {
return false
}
return v.Confirmed
}
func (s *MemoryStore) setItem(key string, item Item) {
s.mu.Lock()
defer s.mu.Unlock()
s.extensionsMap[key] = item
}
func (s *MemoryStore) getItem(key string) (Item, bool) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.extensionsMap[key]
return v, ok
}

178
internal/pass/sync/sync.go Normal file
View File

@ -0,0 +1,178 @@
package sync
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection"
"github.com/twofas/2fas-server/internal/pass/sign"
)
type Syncing struct {
store store
signSvc *sign.Service
}
type store interface {
RequestSync(fmtToken string)
ConfirmSync(fmtToken string) bool
IsSyncCofirmed(fmtToken string) bool
}
func NewPairingApp(signService *sign.Service, fakeMobilePush bool) *Syncing {
return &Syncing{
store: NewMemoryStore(),
signSvc: signService,
}
}
const (
pairingTokenValidityDuration = 3 * time.Minute
)
type ConfigureBrowserExtensionRequest struct {
ExtensionID string `json:"extension_id"`
}
type ConfigureBrowserExtensionResponse struct {
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
ConnectionToken string `json:"connection_token"`
}
type ExtensionWaitForConnectionInput struct {
ResponseWriter http.ResponseWriter
HttpReq *http.Request
}
type RequestSyncResponse struct {
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
Status string `json:"status"`
}
type MobileSyncPayload struct {
MobileSyncToken string `json:"mobile_sync_token"`
}
func (p *Syncing) ServeSyncingRequestWS(w http.ResponseWriter, r *http.Request, fcmToken string) error {
log := logging.WithField("fcm_token", fcmToken)
conn, err := connection.Upgrade(w, r)
if err != nil {
return fmt.Errorf("failed to upgrade connection: %w", err)
}
defer conn.Close()
log.Infof("Starting sync request WS for %q", fcmToken)
p.requestSync(r.Context(), fcmToken)
if syncDone := p.isSyncConfirmed(r.Context(), fcmToken); syncDone {
if err := p.sendTokenAndCloseConn(fcmToken, conn); err != nil {
log.Errorf("Failed to send token: %v", err)
}
log.Infof("Paring ws finished")
return nil
}
const (
maxWaitTime = 3 * time.Minute
checkIfConnectedInterval = time.Second
)
maxWaitC := time.After(maxWaitTime)
connectedCheckTicker := time.NewTicker(checkIfConnectedInterval)
defer connectedCheckTicker.Stop()
for {
select {
case <-maxWaitC:
log.Info("Closing paring ws after timeout")
return nil
case <-connectedCheckTicker.C:
if syncConfirmed := p.isSyncConfirmed(r.Context(), fcmToken); syncConfirmed {
if err := p.sendTokenAndCloseConn(fcmToken, conn); err != nil {
log.Errorf("Failed to send token: %v", err)
return nil
}
log.Infof("Paring ws finished")
return nil
}
}
}
}
func (p *Syncing) isSyncConfirmed(ctx context.Context, fcmToken string) bool {
return p.store.IsSyncCofirmed(fcmToken)
}
func (p *Syncing) requestSync(ctx context.Context, fcmToken string) {
p.store.RequestSync(fcmToken)
}
func (p *Syncing) sendTokenAndCloseConn(fcmToken string, conn *websocket.Conn) error {
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: fcmToken,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeBrowserExtensionSync,
})
if err != nil {
return fmt.Errorf("failed to generate ext proxy token: %v", err)
}
if err := conn.WriteJSON(RequestSyncResponse{
BrowserExtensionProxyToken: extProxyToken,
Status: "ok",
}); err != nil {
return fmt.Errorf("failed to write to extension: %v", err)
}
return conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
}
func (p *Syncing) sendMobileToken(fcmToken string, resp http.ResponseWriter) error {
extProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: fcmToken,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileSyncConfirm,
})
if err != nil {
return fmt.Errorf("failed to generate ext proxy token: %v", err)
}
bb, err := json.Marshal(struct {
MobileSyncConfirmToken string `json:"mobile_sync_confirm_token"`
}{
MobileSyncConfirmToken: extProxyToken,
})
if err != nil {
return fmt.Errorf("failed to write to extension: %v", err)
}
resp.Write(bb)
return nil
}
type ConfirmSyncResponse struct {
ProxyToken string `json:"proxy_token"`
}
var noSyncRequestErr = errors.New("sync request was not created")
func (p *Syncing) confirmSync(ctx context.Context, fcmToken string) (ConfirmSyncResponse, error) {
logging.Infof("Starting sync confirm for %q", fcmToken)
mobileProxyToken, err := p.signSvc.SignAndEncode(sign.Message{
ConnectionID: fcmToken,
ExpiresAt: time.Now().Add(pairingTokenValidityDuration),
ConnectionType: sign.ConnectionTypeMobileSyncConfirm,
})
if err != nil {
return ConfirmSyncResponse{}, fmt.Errorf("failed to generate ext proxy token: %v", err)
}
if ok := p.store.ConfirmSync(fcmToken); !ok {
return ConfirmSyncResponse{}, noSyncRequestErr
}
return ConfirmSyncResponse{ProxyToken: mobileProxyToken}, nil
}

View File

@ -3,7 +3,7 @@
apt install --assume-yes jq
AWS_REGION=us-east-1
KEY_ALIAS=pass_service
KEY_ALIAS=pass_service_signing_key
response=$(awslocal kms create-key \
--region $AWS_REGION \

158
tests/pass/http.go Normal file
View File

@ -0,0 +1,158 @@
package pass
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/avast/retry-go/v4"
"github.com/google/uuid"
)
var (
httpClient = http.DefaultClient
)
func getApiURL() string {
apiURL := os.Getenv("API_URL")
if apiURL != "" {
return apiURL
}
return "http://" + getPassAddr()
}
func getPassAddr() string {
addr := os.Getenv("PASS_ADDR")
if addr != "" {
return addr
}
return "localhost:8082"
}
type ConfigureBrowserExtensionResponse struct {
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
ConnectionToken string `json:"connection_token"`
}
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
extensionID := uuid.NewString()
if extensionIDFromEnv := os.Getenv("TEST_EXTENSION_ID"); extensionIDFromEnv != "" {
extensionID = extensionIDFromEnv
}
req := struct {
ExtensionID string `json:"extension_id"`
}{
ExtensionID: extensionID,
}
var resp ConfigureBrowserExtensionResponse
if err := request("POST", "/browser_extension/configure", "", req, &resp); err != nil {
return resp, fmt.Errorf("failed to configure browser: %w", err)
}
return resp, nil
}
// confirmMobile confirms pairing and returns mobile proxy token.
func confirmMobile(connectionToken, fcm string) (string, error) {
deviceID := uuid.NewString()
if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" {
deviceID = deviceIDFromEnv
}
req := struct {
DeviceID string `json:"device_id"`
FCMToken string `json:"fcm_token"`
}{
DeviceID: deviceID,
FCMToken: fcm,
}
resp := struct {
ProxyToken string `json:"proxy_token"`
}{}
if err := request("POST", "/mobile/confirm", connectionToken, req, &resp); err != nil {
return "", fmt.Errorf("failed to configure browser: %w", err)
}
return resp.ProxyToken, nil
}
// confirmSyncMobile confirms pairing and returns mobile proxy token.
func confirmSyncMobile(connectionToken string) (string, error) {
var result string
err := retry.Do(func() error {
var err error
result, err = confirmSyncMobileRequest(connectionToken)
return err
})
return result, err
}
func confirmSyncMobileRequest(connectionToken string) (string, error) {
var resp struct {
ProxyToken string `json:"proxy_token"`
}
if err := request("POST", "/mobile/sync/confirm", connectionToken, nil, &resp); err != nil {
return "", fmt.Errorf("failed to confirm mobile: %w", err)
}
return resp.ProxyToken, nil
}
func getMobileToken(fcm string) (string, error) {
var resp struct {
MobileSyncConfirmToken string `json:"mobile_sync_confirm_token"`
}
if err := request("GET", fmt.Sprintf("/mobile/sync/%s/token", fcm), "", nil, &resp); err != nil {
return "", fmt.Errorf("failed to get mobile token")
}
return resp.MobileSyncConfirmToken, nil
}
func request(method, path, auth string, req, resp interface{}) error {
url := getApiURL() + path
var body io.Reader
if req != nil {
bb, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to request marshal: %w", err)
}
body = bytes.NewBuffer(bb)
}
httpReq, err := http.NewRequest(method, url, body)
if err != nil {
return fmt.Errorf("failed to create http request: %w", err)
}
if auth != "" {
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", auth))
}
httpResp, err := httpClient.Do(httpReq)
if err != nil {
return fmt.Errorf("failed perform the request: %w", err)
}
defer httpResp.Body.Close()
bb, err := io.ReadAll(httpResp.Body)
if err != nil {
return fmt.Errorf("failed to read body from response: %w", err)
}
if httpResp.StatusCode >= 300 {
return fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
}
if err := json.Unmarshal(bb, &resp); err != nil {
return fmt.Errorf("failed to decode the response: %w", err)
}
return nil
}

View File

@ -26,7 +26,7 @@ func TestSignAndVerifyHappyPath(t *testing.T) {
t.Fatal(err)
}
kmsClient := kms.New(sess)
srv, err := sign.NewService("alias/pass_service", kmsClient)
srv, err := sign.NewService("alias/pass_service_signing_key", kmsClient)
if err != nil {
t.Fatal(err)
}
@ -61,7 +61,7 @@ func TestSignAndVerify(t *testing.T) {
t.Fatal(err)
}
kmsClient := kms.New(sess)
srv, err := sign.NewService("alias/pass_service", kmsClient)
srv, err := sign.NewService("alias/pass_service_signing_key", kmsClient)
if err != nil {
t.Fatal(err)
}

60
tests/pass/pair_test.go Normal file
View File

@ -0,0 +1,60 @@
package pass
import (
"testing"
"github.com/google/uuid"
)
func TestPairHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension()
if err != nil {
t.Fatalf("Failed to configure browser extension: %v", err)
}
browserExtensionDone := make(chan struct{})
mobileDone := make(chan struct{})
go func() {
defer close(browserExtensionDone)
extProxyToken, _, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
if err != nil {
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
return
}
err = proxyWebSocket(
getWsURL()+"/browser_extension/proxy_to_mobile",
extProxyToken,
"sent from browser extension",
"sent from mobile")
if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err)
return
}
}()
go func() {
defer close(mobileDone)
mobileProxyToken, err := confirmMobile(resp.ConnectionToken, uuid.NewString())
if err != nil {
t.Errorf("Mobile: confirm failed: %v", err)
return
}
err = proxyWebSocket(
getWsURL()+"/mobile/proxy_to_browser_extension",
mobileProxyToken,
"sent from mobile",
"sent from browser extension",
)
if err != nil {
t.Errorf("Mobile: proxy failed: %v", err)
return
}
}()
<-browserExtensionDone
<-mobileDone
}

View File

@ -1,262 +0,0 @@
package pass
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"testing"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
type ConfigureBrowserExtensionResponse struct {
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
ConnectionToken string `json:"connection_token"`
}
var (
httpClient = http.DefaultClient
wsDialer = websocket.DefaultDialer
)
func getApiURL() string {
apiURL := os.Getenv("API_URL")
if apiURL != "" {
return apiURL
}
return "http://" + getPassAddr()
}
func getWsURL() string {
wsURL := os.Getenv("WS_URL")
if wsURL != "" {
return wsURL
}
return "ws://" + getPassAddr()
}
func getPassAddr() string {
addr := os.Getenv("PASS_ADDR")
if addr != "" {
return addr
}
return "localhost:8082"
}
func TestPassHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension()
if err != nil {
t.Fatalf("Failed to configure browser extension: %v", err)
}
browserExtensionDone := make(chan struct{})
mobileDone := make(chan struct{})
go func() {
defer close(browserExtensionDone)
extProxyToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
if err != nil {
t.Errorf("Error when Browser Extension waited for confirm: %v", err)
return
}
err = proxyWebSocket(
getWsURL()+"/browser_extension/proxy_to_mobile",
extProxyToken,
"sent from browser extension",
"sent from mobile")
if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err)
return
}
}()
go func() {
defer close(mobileDone)
mobileProxyToken, err := confirmMobile(resp.ConnectionToken)
if err != nil {
t.Errorf("Mobile: confirm failed: %v", err)
return
}
err = proxyWebSocket(
getWsURL()+"/mobile/proxy_to_browser_extension",
mobileProxyToken,
"sent from mobile",
"sent from browser extension",
)
if err != nil {
t.Errorf("Mobile: proxy failed: %v", err)
return
}
}()
<-browserExtensionDone
<-mobileDone
}
func browserExtensionWaitForConfirm(token string) (string, error) {
url := getWsURL() + "/browser_extension/wait_for_connection"
var resp struct {
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
Status string `json:"status"`
DeviceID string `json:"device_id"`
}
conn, err := dialWS(url, token)
if err != nil {
return "", err
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
return "", fmt.Errorf("error reading from connection: %w", err)
}
if err := json.Unmarshal(message, &resp); err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
const expectedStatus = "ok"
if resp.Status != expectedStatus {
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
}
return resp.BrowserExtensionProxyToken, nil
}
func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) {
url := getApiURL() + "/browser_extension/configure"
extensionID := uuid.NewString()
if extensionIDFromEnv := os.Getenv("TEST_EXTENSION_ID"); extensionIDFromEnv != "" {
extensionID = extensionIDFromEnv
}
req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, extensionID))
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to create http request: %w", err)
}
httpResp, err := httpClient.Do(req)
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed perform the request: %w", err)
}
defer httpResp.Body.Close()
bb, err := io.ReadAll(httpResp.Body)
if err != nil {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to read body from response: %w", err)
}
if httpResp.StatusCode >= 300 {
return ConfigureBrowserExtensionResponse{}, fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
}
var resp ConfigureBrowserExtensionResponse
if err := json.Unmarshal(bb, &resp); err != nil {
return resp, fmt.Errorf("failed to decode the response: %w", err)
}
return resp, nil
}
// confirmMobile confirms pairing and returns mobile proxy token.
func confirmMobile(connectionToken string) (string, error) {
url := getApiURL() + "/mobile/confirm"
deviceID := uuid.NewString()
if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" {
deviceID = deviceIDFromEnv
}
req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, deviceID))
if err != nil {
return "", fmt.Errorf("failed to prepare the reqest: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken))
httpResp, err := httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to perform the reqest: %w", err)
}
defer httpResp.Body.Close()
bb, err := io.ReadAll(httpResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read body from response: %w", err)
}
if httpResp.StatusCode >= 300 {
return "", fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb))
}
var resp struct {
ProxyToken string `json:"proxy_token"`
}
if err := json.Unmarshal(bb, &resp); err != nil {
return "", fmt.Errorf("failed to decode the response: %w", err)
}
return resp.ProxyToken, nil
}
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
// read exactly one message (and then check it is `expectedReadMsg`).
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
conn, err := dialWS(url, token)
if err != nil {
return err
}
defer conn.Close()
doneReading := make(chan error)
go func() {
defer close(doneReading)
_, message, err := conn.ReadMessage()
if err != nil {
doneReading <- fmt.Errorf("faile to read message: %w", err)
return
}
if string(message) != expectedReadMsg {
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
return
}
}()
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
err, _ = <-doneReading
if err != nil {
return fmt.Errorf("error when reading: %w", err)
}
return nil
}
func dialWS(url, auth string) (*websocket.Conn, error) {
authEncodedAsProtocol := fmt.Sprintf("base64url.bearer.authorization.2pass.io.%s", base64.RawURLEncoding.EncodeToString([]byte(auth)))
conn, _, err := wsDialer.Dial(url, http.Header{
"Sec-WebSocket-Protocol": []string{
"2pass.io",
authEncodedAsProtocol,
},
})
if err != nil {
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
}
return conn, nil
}
func bytesPrintf(format string, ii ...interface{}) io.Reader {
s := fmt.Sprintf(format, ii...)
return bytes.NewBufferString(s)
}

83
tests/pass/sync_test.go Normal file
View File

@ -0,0 +1,83 @@
package pass
import (
"testing"
"github.com/google/uuid"
)
func TestSyncHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension()
if err != nil {
t.Fatalf("Failed to configure browser extension: %v", err)
}
browserExtensionDone := make(chan struct{})
mobileParingDone := make(chan struct{})
fcm := uuid.NewString()
go func() {
defer close(browserExtensionDone)
_, syncToken, err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken)
if err != nil {
t.Errorf("Error when Browser Extension waited for pairing confirm: %v", err)
return
}
proxyToken, err := browserExtensionWaitForSyncConfirm(syncToken)
if err != nil {
t.Errorf("Error when Browser Extension waited for sync confirm: %v", err)
return
}
err = proxyWebSocket(
getWsURL()+"/browser_extension/sync/proxy",
proxyToken,
"sent from browser extension",
"sent from mobile")
if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err)
return
}
}()
go func() {
defer close(mobileParingDone)
_, err := confirmMobile(resp.ConnectionToken, fcm)
if err != nil {
t.Errorf("Mobile: confirm failed: %v", err)
return
}
confirmToken, err := getMobileToken(fcm)
if err != nil {
t.Errorf("Failed to fetch mobile token: %v", err)
return
}
proxyToken, err := confirmSyncMobile(confirmToken)
if err != nil {
t.Errorf("Failed to confirm mobile: %v", err)
return
}
if proxyToken == "" {
t.Errorf("Mobile: proxy token is empty")
return
}
err = proxyWebSocket(
getWsURL()+"/mobile/sync/proxy",
proxyToken,
"sent from mobile",
"sent from browser extension",
)
if err != nil {
t.Errorf("Mobile: proxy failed: %v", err)
return
}
}()
<-browserExtensionDone
<-mobileParingDone
}

133
tests/pass/ws.go Normal file
View File

@ -0,0 +1,133 @@
package pass
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
)
var (
wsDialer = websocket.DefaultDialer
)
func getWsURL() string {
wsURL := os.Getenv("WS_URL")
if wsURL != "" {
return wsURL
}
return "ws://" + getPassAddr()
}
func browserExtensionWaitForSyncConfirm(token string) (string, error) {
url := getWsURL() + "/browser_extension/sync/request"
var resp struct {
BrowserExtensionSyncToken string `json:"browser_extension_proxy_token"`
Status string `json:"status"`
}
conn, err := dialWS(url, token)
if err != nil {
return "", err
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(15 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
return "", fmt.Errorf("error reading from connection: %w", err)
}
if err := json.Unmarshal(message, &resp); err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
const expectedStatus = "ok"
if resp.Status != expectedStatus {
return "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
}
return resp.BrowserExtensionSyncToken, nil
}
func browserExtensionWaitForConfirm(token string) (string, string, error) {
url := getWsURL() + "/browser_extension/wait_for_connection"
var resp struct {
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
BrowserExtensionSyncToken string `json:"browser_extension_sync_token"`
Status string `json:"status"`
DeviceID string `json:"device_id"`
}
conn, err := dialWS(url, token)
if err != nil {
return "", "", err
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
return "", "", fmt.Errorf("error reading from connection: %w", err)
}
if err := json.Unmarshal(message, &resp); err != nil {
return "", "", fmt.Errorf("failed to decode message: %w", err)
}
const expectedStatus = "ok"
if resp.Status != expectedStatus {
return "", "", fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus)
}
return resp.BrowserExtensionProxyToken, resp.BrowserExtensionSyncToken, nil
}
func dialWS(url, auth string) (*websocket.Conn, error) {
authEncodedAsProtocol := fmt.Sprintf("base64url.bearer.authorization.2pass.io.%s", base64.RawURLEncoding.EncodeToString([]byte(auth)))
conn, _, err := wsDialer.Dial(url, http.Header{
"Sec-WebSocket-Protocol": []string{
"2pass.io",
authEncodedAsProtocol,
},
})
if err != nil {
return nil, fmt.Errorf("failed to dial ws %q: %w", url, err)
}
return conn, nil
}
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
// read exactly one message (and then check it is `expectedReadMsg`).
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
conn, err := dialWS(url, token)
if err != nil {
return err
}
defer conn.Close()
doneReading := make(chan error)
go func() {
defer close(doneReading)
_, message, err := conn.ReadMessage()
if err != nil {
doneReading <- fmt.Errorf("faile to read message: %w", err)
return
}
if string(message) != expectedReadMsg {
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
return
}
}()
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
err, _ = <-doneReading
if err != nil {
return fmt.Errorf("error when reading: %w", err)
}
return nil
}