use unit tests for proxy

This commit is contained in:
Krzysztof Dryś 2024-06-28 09:37:39 +02:00
parent bcd20584ff
commit bda7098a8b
8 changed files with 284 additions and 99 deletions

View File

@ -2,7 +2,6 @@ package pass
import (
"testing"
"time"
"github.com/google/uuid"
)
@ -17,26 +16,6 @@ func msgOfSize(size int, c byte) string {
return string(msg)
}
func TestDelayedCommunication(t *testing.T) {
resp, err := configureBrowserExtension()
if err != nil {
t.Fatalf("Failed to configure browser extension: %v", err)
}
t.Run("BE sleeps before sending a message", func(t *testing.T) {
deviceID := getDeviceID()
testPairing(t, deviceID, resp, 21*time.Second, 0)
})
t.Run("Mobile sleeps before sending a message", func(t *testing.T) {
deviceID := getDeviceID()
testPairing(t, deviceID, resp, 0, 21*time.Second)
})
t.Run("Both sleep before sending a message", func(t *testing.T) {
deviceID := getDeviceID()
testPairing(t, deviceID, resp, 21*time.Second, 21*time.Second)
})
}
func TestPairHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension()
if err != nil {
@ -44,7 +23,7 @@ func TestPairHappyFlow(t *testing.T) {
}
deviceID := getDeviceID()
testPairing(t, deviceID, resp, 0, 0)
testPairing(t, deviceID, resp)
}
func TestPairMultipleTimes(t *testing.T) {
@ -55,14 +34,14 @@ func TestPairMultipleTimes(t *testing.T) {
deviceID := getDeviceID()
for i := 0; i < 10; i++ {
testPairing(t, deviceID, resp, 0, 0)
testPairing(t, deviceID, resp)
if t.Failed() {
break
}
}
}
func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse, sleepBeforeSendBE, sleepBeforeSendMobile time.Duration) {
func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse) {
t.Helper()
browserExtensionDone := make(chan struct{})
@ -84,7 +63,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe
extProxyToken,
msgOfSize(messageSize, 'b'),
msgOfSize(messageSize, 'm'),
sleepBeforeSendBE,
)
if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err)
@ -106,7 +84,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe
mobileProxyToken,
msgOfSize(messageSize, 'm'),
msgOfSize(messageSize, 'b'),
sleepBeforeSendMobile,
)
if err != nil {
t.Errorf("Mobile: proxy failed: %v", err)

View File

@ -2,7 +2,6 @@ package pass
import (
"testing"
"time"
"github.com/google/uuid"
)
@ -53,9 +52,7 @@ func TestSyncHappyFlow(t *testing.T) {
getWsURL()+"/browser_extension/sync/proxy",
proxyToken,
"sent from browser extension",
"sent from mobile",
time.Duration(0),
)
"sent from mobile")
if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err)
return
@ -87,7 +84,6 @@ func TestSyncHappyFlow(t *testing.T) {
proxyToken,
"sent from mobile",
"sent from browser extension",
time.Duration(0),
)
if err != nil {
t.Errorf("Mobile: proxy failed: %v", err)

View File

@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"os"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
@ -101,9 +100,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) {
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
// read exactly one message (and then check it is `expectedReadMsg`).
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBeforeSend time.Duration) error {
const wsPingFrequency = 5 * time.Second // how often server send pings
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
conn, err := dialWS(url, token)
if err != nil {
return err
@ -111,30 +108,20 @@ func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBe
defer conn.Close()
doneReading := make(chan error)
doneWriting := atomic.Bool{}
doneWriting.Store(false)
go func() {
defer close(doneReading)
_, message, err := conn.ReadMessage()
if err != nil {
doneReading <- fmt.Errorf("failed to read message: %w", err)
doneReading <- fmt.Errorf("faile to read message: %w", err)
return
}
if string(message) != expectedReadMsg {
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
}
for !doneWriting.Load() {
conn.SetReadDeadline(time.Now().Add(wsPingFrequency + time.Second))
_, _, err = conn.ReadMessage()
if err != nil {
return
}
return
}
}()
time.Sleep(sleepBeforeSend)
defer doneWriting.Store(true)
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}

View File

@ -1,4 +1,4 @@
package connection
package proxy
import (
"bytes"
@ -12,19 +12,36 @@ import (
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 20 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = pongWait / 4
DefaultWriteTimeout = 10 * time.Second
DefaultReadTimeout = 20 * time.Second
DefaultPingFrequency = DefaultReadTimeout / 4
DefaultDisconnectAfter = 3 * time.Minute
// Maximum message size allowed from peer.
maxMessageSize = 10 * (2 << 20)
)
type Config struct {
WriteTimeout time.Duration
ReadTimeout time.Duration
PingFrequency time.Duration
DisconnectAfter time.Duration
}
func DefaultConfig() Config {
return Config{
WriteTimeout: DefaultWriteTimeout,
ReadTimeout: DefaultReadTimeout,
PingFrequency: DefaultPingFrequency,
DisconnectAfter: DefaultDisconnectAfter,
}
}
type WriterCloser interface {
Write(msg []byte)
Close()
}
var (
newline = []byte{'\n'}
space = []byte{' '}
@ -37,20 +54,21 @@ var (
}
)
// proxy is a responsible for reading from read chan and sending it over wsConn
// and reading fom wsChan and sending it over send chan
// proxy is a responsible for reading from reader chan and sending it over conn
// and reading fom conn and sending it over writer.
type proxy struct {
send *safeChannel
read chan []byte
conn *websocket.Conn
writer WriterCloser
reader chan []byte
conn *websocket.Conn
cfg Config
}
func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
func Start(wsConn *websocket.Conn, writer WriterCloser, reader chan []byte, cfg Config) {
proxy := &proxy{
send: send,
read: read,
conn: wsConn,
writer: writer,
reader: reader,
conn: wsConn,
cfg: cfg,
}
wg := sync.WaitGroup{}
@ -67,11 +85,10 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
})
go recovery.DoNotPanic(func() {
disconnectAfter := 3 * time.Minute
timeout := time.After(disconnectAfter)
timeout := time.After(cfg.DisconnectAfter)
<-timeout
logging.Info("Connection closed after", disconnectAfter)
logging.Info("Connection closed after", cfg.DisconnectAfter)
proxy.conn.Close()
})
@ -79,7 +96,7 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
wg.Wait()
}
// readPump pumps messages from the websocket proxy to send.
// readPump pumps messages from the websocket proxy to writer.
//
// The application runs readPump in a per-proxy goroutine. The application
// ensures that there is at most one reader on a proxy by executing all
@ -87,13 +104,13 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
func (p *proxy) readPump() {
defer func() {
p.conn.Close()
p.send.close()
p.writer.Close()
}()
p.conn.SetReadLimit(maxMessageSize)
p.conn.SetReadDeadline(time.Now().Add(pongWait))
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
p.conn.SetPongHandler(func(string) error {
p.conn.SetReadDeadline(time.Now().Add(pongWait))
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
return nil
})
@ -112,17 +129,17 @@ func (p *proxy) readPump() {
break
}
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
p.send.write(message)
p.writer.Write(message)
}
}
// writePump pumps messages from the read chan to the websocket proxy.
// writePump pumps messages from the reader chan to the websocket proxy.
//
// A goroutine running writePump is started for each proxy. The
// application ensures that there is at most one writer to a proxy by
// executing all writes from this goroutine.
func (p *proxy) writePump() {
ticker := time.NewTicker(pingPeriod)
ticker := time.NewTicker(p.cfg.PingFrequency)
defer func() {
ticker.Stop()
p.conn.Close()
@ -130,8 +147,8 @@ func (p *proxy) writePump() {
for {
select {
case message, ok := <-p.read:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
case message, ok := <-p.reader:
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if !ok {
// The hub closed the channel.
p.conn.WriteMessage(websocket.CloseMessage, []byte{})
@ -148,7 +165,7 @@ func (p *proxy) writePump() {
return
}
case <-ticker.C:
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}

View File

@ -10,13 +10,13 @@ type proxyPool struct {
proxies map[string]*proxyPair
}
// registerMobileConn register proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair {
// getOrCreateProxyPair registers proxyPair if not existing in pool and returns it.
func (pp *proxyPool) getOrCreateProxyPair(id string, disconnectAfter time.Duration) *proxyPair {
pp.mu.Lock()
defer pp.mu.Unlock()
v, ok := pp.proxies[id]
if !ok {
v = initProxyPair()
v = initProxyPair(disconnectAfter)
}
pp.proxies[id] = v
return v
@ -48,12 +48,11 @@ type proxyPair struct {
}
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
func initProxyPair() *proxyPair {
const proxyTimeout = 3 * time.Minute
func initProxyPair(disconnectAfter time.Duration) *proxyPair {
return &proxyPair{
toMobileDataCh: newSafeChannel(),
toExtensionDataCh: newSafeChannel(),
expiresAt: time.Now().Add(proxyTimeout),
expiresAt: time.Now().Add(disconnectAfter + time.Minute),
}
}
@ -69,7 +68,7 @@ func newSafeChannel() *safeChannel {
}
}
func (sc *safeChannel) write(data []byte) {
func (sc *safeChannel) Write(data []byte) {
sc.mu.Lock()
defer sc.mu.Unlock()
@ -80,7 +79,7 @@ func (sc *safeChannel) write(data []byte) {
sc.channel <- data
}
func (sc *safeChannel) close() {
func (sc *safeChannel) Close() {
sc.mu.Lock()
defer sc.mu.Unlock()

View File

@ -6,15 +6,17 @@ import (
"time"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
)
// ProxyServer manages proxy connections between Browser Extension and Mobile.
type ProxyServer struct {
proxyPool *proxyPool
idLabel string
proxyPool *proxyPool
idLabel string
proxyConfig proxy.Config
}
func NewProxyServer(idLabel string) *ProxyServer {
func NewProxyServer(idLabel string, proxyConfig proxy.Config) *ProxyServer {
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
go func() {
ticker := time.NewTicker(30 * time.Second)
@ -24,8 +26,9 @@ func NewProxyServer(idLabel string) *ProxyServer {
}
}()
return &ProxyServer{
proxyPool: proxyPool,
idLabel: idLabel,
proxyPool: proxyPool,
idLabel: idLabel,
proxyConfig: proxyConfig,
}
}
@ -38,8 +41,8 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht
log.Infof("Starting ServeExtensionProxyToMobileWS")
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel)
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)
proxy.Start(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel, p.proxyConfig)
p.proxyPool.deleteProxyPair(id)
return nil
@ -52,9 +55,9 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht
}
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)
startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel)
proxy.Start(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel, p.proxyConfig)
p.proxyPool.deleteProxyPair(id)
return nil

View File

@ -0,0 +1,205 @@
package connection
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
"github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
)
func init() {
logging.Init(nil)
}
// TestProxy sends message both ways and makes sure it is received correctly.
func TestProxy(t *testing.T) {
ws1, ws2, cleanup := setupConnections(t, proxy.DefaultConfig())
defer cleanup()
testWriteReceive(t, ws1, ws2)
testWriteReceive(t, ws2, ws1)
}
// TestConnectionIsClosedAfterTheSpecifiedTime checks that `DisconnectAfter` is obeyed by the proxy server.
func TestConnectionIsClosedAfterTheSpecifiedTime(t *testing.T) {
timeout := time.Second
ws1, ws2, cleanup := setupConnections(t, proxy.Config{
WriteTimeout: proxy.DefaultWriteTimeout,
ReadTimeout: proxy.DefaultReadTimeout,
PingFrequency: proxy.DefaultPingFrequency,
DisconnectAfter: timeout,
})
defer cleanup()
// Exchange some data to make sure the connection is established.
testWriteReceive(t, ws1, ws2)
testWriteReceive(t, ws2, ws1)
// Neither side of the connection sends any message, they just wait on read. Therefore, in both cases ReadMessage
// should exit after the server closes the connection.
ws1Result := make(chan error)
ws2Result := make(chan error)
go func() {
_, _, err := ws1.ReadMessage()
ws1Result <- err
}()
go func() {
_, _, err := ws2.ReadMessage()
ws2Result <- err
}()
// Finish test after timeout and check if connections were closed. One would expect a race condition here
// (we check exactly after timeout) but this test seems to be stable. This is because we have already spent some time
// exchanging the data before waiting for the timeout.
after := time.After(timeout)
var err1, err2 error
done := false
for !done {
select {
case err1 = <-ws1Result:
case err2 = <-ws2Result:
case <-after:
done = true
}
}
if err1 == nil {
t.Logf("WebSocket 1 connection wasn't closed")
}
if err2 == nil {
t.Logf("WebSocket 2 connection wasn't closed")
}
}
// TestPingPongIsEnoughToKeepUsAlive check that the connection is kept alive by the ws native ping-pong mechanism.
// In the Browser Extension the pong response is sent by the browser automatically, in this test framework does it for us
// in ReadMessage.
func TestPingPongIsEnoughToKeepUsAlive(t *testing.T) {
readTimeout := time.Second
ws1, ws2, cleanup := setupConnections(t, proxy.Config{
WriteTimeout: proxy.DefaultWriteTimeout,
ReadTimeout: readTimeout,
PingFrequency: readTimeout / 4,
DisconnectAfter: time.Minute,
})
defer cleanup()
group := errgroup.Group{}
group.Go(func() error {
_, _, err := ws1.ReadMessage()
return err
})
group.Go(func() error {
_, _, err := ws2.ReadMessage()
return err
})
time.Sleep(4 * readTimeout)
// Write some messages to both websockets. This has two benefits:
// 1. It ensures the connections are still alive,
// 2. It makes ReadMessage above return, so group.Wait will exit.
if err := ws1.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil {
t.Errorf("Failed to write message to the first websocket: %v", err)
}
if err := ws2.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil {
t.Errorf("Failed to write message to the second websocket: %v", err)
}
err := group.Wait()
if err != nil {
t.Errorf("Error when reading from websocket: %v", err)
}
}
// setupConnections creates new test websocket server and two connected clients paired in a proxy.
func setupConnections(t *testing.T, cfg proxy.Config) (*websocket.Conn, *websocket.Conn, func()) {
s := httptest.NewServer(testHandler{
t: t,
ps: NewProxyServer("id", cfg),
})
ws1, _, err := testDialer.Dial(makeWsURL(s.URL, "mobile", "1"), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
ws2, _, err := testDialer.Dial(makeWsURL(s.URL, "extension", "1"), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
cleanup := func() {
ws1.Close()
ws2.Close()
s.Close()
}
return ws1, ws2, cleanup
}
var testDialer = websocket.Dialer{
Subprotocols: []string{"2pass.io"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
// testHandler is for handling http connections.
type testHandler struct {
t *testing.T
ps *ProxyServer
}
func (t testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/mobile" {
t.ps.ServeExtensionProxyToMobileWS(w, r, r.URL.Query().Get("id"))
} else if r.URL.Path == "/extension" {
t.ps.ServeMobileProxyToExtensionWS(w, r, r.URL.Query().Get("id"))
} else {
http.Error(w, "invalid path", http.StatusNotFound)
}
}
// makeWsURL constructs the WebSocket from the test server's URL.
func makeWsURL(s string, app string, id string) string {
return fmt.Sprintf("ws%s/%s?id=%s", strings.TrimPrefix(s, "http"), app, id)
}
// testWriteReceive writes a message to w1 and makes sure it is received by w2.
func testWriteReceive(t *testing.T, ws1, ws2 *websocket.Conn) {
t.Helper()
const message = "Hello, WebSocket!"
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
_, received, err := ws2.ReadMessage()
if err != nil {
t.Errorf("Failed to read message: %v", err)
return
}
if string(received) != message {
t.Errorf("Expected %q, received %q", message, string(received))
}
}()
if err := ws1.WriteMessage(websocket.BinaryMessage, []byte(message)); err != nil {
t.Errorf("Failed to write message: %v", err)
}
wg.Wait()
}

View File

@ -13,6 +13,7 @@ import (
httphelpers "github.com/twofas/2fas-server/internal/common/http"
"github.com/twofas/2fas-server/internal/common/recovery"
"github.com/twofas/2fas-server/internal/pass/connection"
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
"github.com/twofas/2fas-server/internal/pass/fcm"
"github.com/twofas/2fas-server/internal/pass/pairing"
"github.com/twofas/2fas-server/internal/pass/sign"
@ -64,10 +65,10 @@ func NewServer(cfg config.PassConfig) *Server {
}
pairingApp := pairing.NewApp(signSvc, cfg.PairingRequestTokenValidityDuration)
proxyPairingApp := connection.NewProxyServer("device_id")
proxyPairingApp := connection.NewProxyServer("device_id", proxy.DefaultConfig())
syncApp := sync.NewApp(signSvc, fcmClient)
proxySyncApp := connection.NewProxyServer("fcm_token")
proxySyncApp := connection.NewProxyServer("fcm_token", proxy.DefaultConfig())
router := gin.New()
router.Use(recovery.RecoveryMiddleware())