diff --git a/e2e-tests/pass/pair_test.go b/e2e-tests/pass/pair_test.go index 3133e3a..01e228f 100644 --- a/e2e-tests/pass/pair_test.go +++ b/e2e-tests/pass/pair_test.go @@ -2,7 +2,6 @@ package pass import ( "testing" - "time" "github.com/google/uuid" ) @@ -17,26 +16,6 @@ func msgOfSize(size int, c byte) string { return string(msg) } -func TestDelayedCommunication(t *testing.T) { - resp, err := configureBrowserExtension() - if err != nil { - t.Fatalf("Failed to configure browser extension: %v", err) - } - - t.Run("BE sleeps before sending a message", func(t *testing.T) { - deviceID := getDeviceID() - testPairing(t, deviceID, resp, 21*time.Second, 0) - }) - t.Run("Mobile sleeps before sending a message", func(t *testing.T) { - deviceID := getDeviceID() - testPairing(t, deviceID, resp, 0, 21*time.Second) - }) - t.Run("Both sleep before sending a message", func(t *testing.T) { - deviceID := getDeviceID() - testPairing(t, deviceID, resp, 21*time.Second, 21*time.Second) - }) -} - func TestPairHappyFlow(t *testing.T) { resp, err := configureBrowserExtension() if err != nil { @@ -44,7 +23,7 @@ func TestPairHappyFlow(t *testing.T) { } deviceID := getDeviceID() - testPairing(t, deviceID, resp, 0, 0) + testPairing(t, deviceID, resp) } func TestPairMultipleTimes(t *testing.T) { @@ -55,14 +34,14 @@ func TestPairMultipleTimes(t *testing.T) { deviceID := getDeviceID() for i := 0; i < 10; i++ { - testPairing(t, deviceID, resp, 0, 0) + testPairing(t, deviceID, resp) if t.Failed() { break } } } -func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse, sleepBeforeSendBE, sleepBeforeSendMobile time.Duration) { +func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse) { t.Helper() browserExtensionDone := make(chan struct{}) @@ -84,7 +63,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe extProxyToken, msgOfSize(messageSize, 'b'), msgOfSize(messageSize, 'm'), - sleepBeforeSendBE, ) if err != nil { t.Errorf("Browser Extension: proxy failed: %v", err) @@ -106,7 +84,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe mobileProxyToken, msgOfSize(messageSize, 'm'), msgOfSize(messageSize, 'b'), - sleepBeforeSendMobile, ) if err != nil { t.Errorf("Mobile: proxy failed: %v", err) diff --git a/e2e-tests/pass/sync_test.go b/e2e-tests/pass/sync_test.go index 4b8ffd7..13447d1 100644 --- a/e2e-tests/pass/sync_test.go +++ b/e2e-tests/pass/sync_test.go @@ -2,7 +2,6 @@ package pass import ( "testing" - "time" "github.com/google/uuid" ) @@ -53,9 +52,7 @@ func TestSyncHappyFlow(t *testing.T) { getWsURL()+"/browser_extension/sync/proxy", proxyToken, "sent from browser extension", - "sent from mobile", - time.Duration(0), - ) + "sent from mobile") if err != nil { t.Errorf("Browser Extension: proxy failed: %v", err) return @@ -87,7 +84,6 @@ func TestSyncHappyFlow(t *testing.T) { proxyToken, "sent from mobile", "sent from browser extension", - time.Duration(0), ) if err != nil { t.Errorf("Mobile: proxy failed: %v", err) diff --git a/e2e-tests/pass/ws.go b/e2e-tests/pass/ws.go index 6e94024..70b9479 100644 --- a/e2e-tests/pass/ws.go +++ b/e2e-tests/pass/ws.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "os" - "sync/atomic" "time" "github.com/gorilla/websocket" @@ -101,9 +100,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) { // proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and // read exactly one message (and then check it is `expectedReadMsg`). -func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBeforeSend time.Duration) error { - const wsPingFrequency = 5 * time.Second // how often server send pings - +func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error { conn, err := dialWS(url, token) if err != nil { return err @@ -111,30 +108,20 @@ func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBe defer conn.Close() doneReading := make(chan error) - doneWriting := atomic.Bool{} - doneWriting.Store(false) go func() { defer close(doneReading) _, message, err := conn.ReadMessage() if err != nil { - doneReading <- fmt.Errorf("failed to read message: %w", err) + doneReading <- fmt.Errorf("faile to read message: %w", err) + return } if string(message) != expectedReadMsg { doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message)) - } - for !doneWriting.Load() { - conn.SetReadDeadline(time.Now().Add(wsPingFrequency + time.Second)) - _, _, err = conn.ReadMessage() - if err != nil { - return - } + return } }() - time.Sleep(sleepBeforeSend) - - defer doneWriting.Store(true) if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil { return fmt.Errorf("failed to write message: %w", err) } diff --git a/internal/pass/connection/proxy.go b/internal/pass/connection/proxy/proxy.go similarity index 59% rename from internal/pass/connection/proxy.go rename to internal/pass/connection/proxy/proxy.go index 55141e7..620acda 100644 --- a/internal/pass/connection/proxy.go +++ b/internal/pass/connection/proxy/proxy.go @@ -1,4 +1,4 @@ -package connection +package proxy import ( "bytes" @@ -12,19 +12,36 @@ import ( ) const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 20 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = pongWait / 4 + DefaultWriteTimeout = 10 * time.Second + DefaultReadTimeout = 20 * time.Second + DefaultPingFrequency = DefaultReadTimeout / 4 + DefaultDisconnectAfter = 3 * time.Minute // Maximum message size allowed from peer. maxMessageSize = 10 * (2 << 20) ) +type Config struct { + WriteTimeout time.Duration + ReadTimeout time.Duration + PingFrequency time.Duration + DisconnectAfter time.Duration +} + +func DefaultConfig() Config { + return Config{ + WriteTimeout: DefaultWriteTimeout, + ReadTimeout: DefaultReadTimeout, + PingFrequency: DefaultPingFrequency, + DisconnectAfter: DefaultDisconnectAfter, + } +} + +type WriterCloser interface { + Write(msg []byte) + Close() +} + var ( newline = []byte{'\n'} space = []byte{' '} @@ -37,20 +54,21 @@ var ( } ) -// proxy is a responsible for reading from read chan and sending it over wsConn -// and reading fom wsChan and sending it over send chan +// proxy is a responsible for reading from reader chan and sending it over conn +// and reading fom conn and sending it over writer. type proxy struct { - send *safeChannel - read chan []byte - - conn *websocket.Conn + writer WriterCloser + reader chan []byte + conn *websocket.Conn + cfg Config } -func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { +func Start(wsConn *websocket.Conn, writer WriterCloser, reader chan []byte, cfg Config) { proxy := &proxy{ - send: send, - read: read, - conn: wsConn, + writer: writer, + reader: reader, + conn: wsConn, + cfg: cfg, } wg := sync.WaitGroup{} @@ -67,11 +85,10 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { }) go recovery.DoNotPanic(func() { - disconnectAfter := 3 * time.Minute - timeout := time.After(disconnectAfter) + timeout := time.After(cfg.DisconnectAfter) <-timeout - logging.Info("Connection closed after", disconnectAfter) + logging.Info("Connection closed after", cfg.DisconnectAfter) proxy.conn.Close() }) @@ -79,7 +96,7 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { wg.Wait() } -// readPump pumps messages from the websocket proxy to send. +// readPump pumps messages from the websocket proxy to writer. // // The application runs readPump in a per-proxy goroutine. The application // ensures that there is at most one reader on a proxy by executing all @@ -87,13 +104,13 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { func (p *proxy) readPump() { defer func() { p.conn.Close() - p.send.close() + p.writer.Close() }() p.conn.SetReadLimit(maxMessageSize) - p.conn.SetReadDeadline(time.Now().Add(pongWait)) + p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout)) p.conn.SetPongHandler(func(string) error { - p.conn.SetReadDeadline(time.Now().Add(pongWait)) + p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout)) return nil }) @@ -112,17 +129,17 @@ func (p *proxy) readPump() { break } message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) - p.send.write(message) + p.writer.Write(message) } } -// writePump pumps messages from the read chan to the websocket proxy. +// writePump pumps messages from the reader chan to the websocket proxy. // // A goroutine running writePump is started for each proxy. The // application ensures that there is at most one writer to a proxy by // executing all writes from this goroutine. func (p *proxy) writePump() { - ticker := time.NewTicker(pingPeriod) + ticker := time.NewTicker(p.cfg.PingFrequency) defer func() { ticker.Stop() p.conn.Close() @@ -130,8 +147,8 @@ func (p *proxy) writePump() { for { select { - case message, ok := <-p.read: - p.conn.SetWriteDeadline(time.Now().Add(writeWait)) + case message, ok := <-p.reader: + p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout)) if !ok { // The hub closed the channel. p.conn.WriteMessage(websocket.CloseMessage, []byte{}) @@ -148,7 +165,7 @@ func (p *proxy) writePump() { return } case <-ticker.C: - p.conn.SetWriteDeadline(time.Now().Add(writeWait)) + p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout)) if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } diff --git a/internal/pass/connection/proxy_pool.go b/internal/pass/connection/proxy_pool.go index 7fa6cfc..677be6e 100644 --- a/internal/pass/connection/proxy_pool.go +++ b/internal/pass/connection/proxy_pool.go @@ -10,13 +10,13 @@ type proxyPool struct { proxies map[string]*proxyPair } -// registerMobileConn register proxyPair if not existing in pool and returns it. -func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair { +// getOrCreateProxyPair registers proxyPair if not existing in pool and returns it. +func (pp *proxyPool) getOrCreateProxyPair(id string, disconnectAfter time.Duration) *proxyPair { pp.mu.Lock() defer pp.mu.Unlock() v, ok := pp.proxies[id] if !ok { - v = initProxyPair() + v = initProxyPair(disconnectAfter) } pp.proxies[id] = v return v @@ -48,12 +48,11 @@ type proxyPair struct { } // initProxyPair returns proxyPair and runs loop responsible for proxing data. -func initProxyPair() *proxyPair { - const proxyTimeout = 3 * time.Minute +func initProxyPair(disconnectAfter time.Duration) *proxyPair { return &proxyPair{ toMobileDataCh: newSafeChannel(), toExtensionDataCh: newSafeChannel(), - expiresAt: time.Now().Add(proxyTimeout), + expiresAt: time.Now().Add(disconnectAfter + time.Minute), } } @@ -69,7 +68,7 @@ func newSafeChannel() *safeChannel { } } -func (sc *safeChannel) write(data []byte) { +func (sc *safeChannel) Write(data []byte) { sc.mu.Lock() defer sc.mu.Unlock() @@ -80,7 +79,7 @@ func (sc *safeChannel) write(data []byte) { sc.channel <- data } -func (sc *safeChannel) close() { +func (sc *safeChannel) Close() { sc.mu.Lock() defer sc.mu.Unlock() diff --git a/internal/pass/connection/proxy_server.go b/internal/pass/connection/proxy_server.go index 99c3dd4..0f09bed 100644 --- a/internal/pass/connection/proxy_server.go +++ b/internal/pass/connection/proxy_server.go @@ -6,15 +6,17 @@ import ( "time" "github.com/twofas/2fas-server/internal/common/logging" + "github.com/twofas/2fas-server/internal/pass/connection/proxy" ) // ProxyServer manages proxy connections between Browser Extension and Mobile. type ProxyServer struct { - proxyPool *proxyPool - idLabel string + proxyPool *proxyPool + idLabel string + proxyConfig proxy.Config } -func NewProxyServer(idLabel string) *ProxyServer { +func NewProxyServer(idLabel string, proxyConfig proxy.Config) *ProxyServer { proxyPool := &proxyPool{proxies: map[string]*proxyPair{}} go func() { ticker := time.NewTicker(30 * time.Second) @@ -24,8 +26,9 @@ func NewProxyServer(idLabel string) *ProxyServer { } }() return &ProxyServer{ - proxyPool: proxyPool, - idLabel: idLabel, + proxyPool: proxyPool, + idLabel: idLabel, + proxyConfig: proxyConfig, } } @@ -38,8 +41,8 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht log.Infof("Starting ServeExtensionProxyToMobileWS") - proxyPair := p.proxyPool.getOrCreateProxyPair(id) - startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel) + proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter) + proxy.Start(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel, p.proxyConfig) p.proxyPool.deleteProxyPair(id) return nil @@ -52,9 +55,9 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht } logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id) - proxyPair := p.proxyPool.getOrCreateProxyPair(id) + proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter) - startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel) + proxy.Start(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel, p.proxyConfig) p.proxyPool.deleteProxyPair(id) return nil diff --git a/internal/pass/connection/proxy_server_test.go b/internal/pass/connection/proxy_server_test.go new file mode 100644 index 0000000..2b74302 --- /dev/null +++ b/internal/pass/connection/proxy_server_test.go @@ -0,0 +1,205 @@ +package connection + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + + "github.com/twofas/2fas-server/internal/common/logging" + "github.com/twofas/2fas-server/internal/pass/connection/proxy" +) + +func init() { + logging.Init(nil) +} + +// TestProxy sends message both ways and makes sure it is received correctly. +func TestProxy(t *testing.T) { + ws1, ws2, cleanup := setupConnections(t, proxy.DefaultConfig()) + defer cleanup() + + testWriteReceive(t, ws1, ws2) + testWriteReceive(t, ws2, ws1) +} + +// TestConnectionIsClosedAfterTheSpecifiedTime checks that `DisconnectAfter` is obeyed by the proxy server. +func TestConnectionIsClosedAfterTheSpecifiedTime(t *testing.T) { + timeout := time.Second + + ws1, ws2, cleanup := setupConnections(t, proxy.Config{ + WriteTimeout: proxy.DefaultWriteTimeout, + ReadTimeout: proxy.DefaultReadTimeout, + PingFrequency: proxy.DefaultPingFrequency, + DisconnectAfter: timeout, + }) + defer cleanup() + + // Exchange some data to make sure the connection is established. + testWriteReceive(t, ws1, ws2) + testWriteReceive(t, ws2, ws1) + + // Neither side of the connection sends any message, they just wait on read. Therefore, in both cases ReadMessage + // should exit after the server closes the connection. + ws1Result := make(chan error) + ws2Result := make(chan error) + go func() { + _, _, err := ws1.ReadMessage() + ws1Result <- err + }() + go func() { + _, _, err := ws2.ReadMessage() + ws2Result <- err + }() + + // Finish test after timeout and check if connections were closed. One would expect a race condition here + // (we check exactly after timeout) but this test seems to be stable. This is because we have already spent some time + // exchanging the data before waiting for the timeout. + after := time.After(timeout) + var err1, err2 error + done := false + for !done { + select { + case err1 = <-ws1Result: + case err2 = <-ws2Result: + case <-after: + done = true + } + } + + if err1 == nil { + t.Logf("WebSocket 1 connection wasn't closed") + } + if err2 == nil { + t.Logf("WebSocket 2 connection wasn't closed") + } +} + +// TestPingPongIsEnoughToKeepUsAlive check that the connection is kept alive by the ws native ping-pong mechanism. +// In the Browser Extension the pong response is sent by the browser automatically, in this test framework does it for us +// in ReadMessage. +func TestPingPongIsEnoughToKeepUsAlive(t *testing.T) { + readTimeout := time.Second + + ws1, ws2, cleanup := setupConnections(t, proxy.Config{ + WriteTimeout: proxy.DefaultWriteTimeout, + ReadTimeout: readTimeout, + PingFrequency: readTimeout / 4, + DisconnectAfter: time.Minute, + }) + defer cleanup() + + group := errgroup.Group{} + group.Go(func() error { + _, _, err := ws1.ReadMessage() + return err + }) + group.Go(func() error { + _, _, err := ws2.ReadMessage() + return err + }) + time.Sleep(4 * readTimeout) + + // Write some messages to both websockets. This has two benefits: + // 1. It ensures the connections are still alive, + // 2. It makes ReadMessage above return, so group.Wait will exit. + if err := ws1.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil { + t.Errorf("Failed to write message to the first websocket: %v", err) + } + if err := ws2.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil { + t.Errorf("Failed to write message to the second websocket: %v", err) + } + + err := group.Wait() + if err != nil { + t.Errorf("Error when reading from websocket: %v", err) + } +} + +// setupConnections creates new test websocket server and two connected clients paired in a proxy. +func setupConnections(t *testing.T, cfg proxy.Config) (*websocket.Conn, *websocket.Conn, func()) { + s := httptest.NewServer(testHandler{ + t: t, + ps: NewProxyServer("id", cfg), + }) + + ws1, _, err := testDialer.Dial(makeWsURL(s.URL, "mobile", "1"), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + ws2, _, err := testDialer.Dial(makeWsURL(s.URL, "extension", "1"), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + cleanup := func() { + ws1.Close() + ws2.Close() + s.Close() + } + + return ws1, ws2, cleanup +} + +var testDialer = websocket.Dialer{ + Subprotocols: []string{"2pass.io"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, +} + +// testHandler is for handling http connections. +type testHandler struct { + t *testing.T + ps *ProxyServer +} + +func (t testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/mobile" { + t.ps.ServeExtensionProxyToMobileWS(w, r, r.URL.Query().Get("id")) + } else if r.URL.Path == "/extension" { + t.ps.ServeMobileProxyToExtensionWS(w, r, r.URL.Query().Get("id")) + } else { + http.Error(w, "invalid path", http.StatusNotFound) + } + +} + +// makeWsURL constructs the WebSocket from the test server's URL. +func makeWsURL(s string, app string, id string) string { + return fmt.Sprintf("ws%s/%s?id=%s", strings.TrimPrefix(s, "http"), app, id) +} + +// testWriteReceive writes a message to w1 and makes sure it is received by w2. +func testWriteReceive(t *testing.T, ws1, ws2 *websocket.Conn) { + t.Helper() + const message = "Hello, WebSocket!" + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + _, received, err := ws2.ReadMessage() + if err != nil { + t.Errorf("Failed to read message: %v", err) + return + } + if string(received) != message { + t.Errorf("Expected %q, received %q", message, string(received)) + } + }() + + if err := ws1.WriteMessage(websocket.BinaryMessage, []byte(message)); err != nil { + t.Errorf("Failed to write message: %v", err) + } + + wg.Wait() +} diff --git a/internal/pass/server.go b/internal/pass/server.go index 16317fe..3a74316 100644 --- a/internal/pass/server.go +++ b/internal/pass/server.go @@ -13,6 +13,7 @@ import ( httphelpers "github.com/twofas/2fas-server/internal/common/http" "github.com/twofas/2fas-server/internal/common/recovery" "github.com/twofas/2fas-server/internal/pass/connection" + "github.com/twofas/2fas-server/internal/pass/connection/proxy" "github.com/twofas/2fas-server/internal/pass/fcm" "github.com/twofas/2fas-server/internal/pass/pairing" "github.com/twofas/2fas-server/internal/pass/sign" @@ -64,10 +65,10 @@ func NewServer(cfg config.PassConfig) *Server { } pairingApp := pairing.NewApp(signSvc, cfg.PairingRequestTokenValidityDuration) - proxyPairingApp := connection.NewProxyServer("device_id") + proxyPairingApp := connection.NewProxyServer("device_id", proxy.DefaultConfig()) syncApp := sync.NewApp(signSvc, fcmClient) - proxySyncApp := connection.NewProxyServer("fcm_token") + proxySyncApp := connection.NewProxyServer("fcm_token", proxy.DefaultConfig()) router := gin.New() router.Use(recovery.RecoveryMiddleware())