mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-07 06:55:49 +01:00
use unit tests for proxy
This commit is contained in:
parent
bcd20584ff
commit
bda7098a8b
@ -2,7 +2,6 @@ package pass
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -17,26 +16,6 @@ func msgOfSize(size int, c byte) string {
|
||||
return string(msg)
|
||||
}
|
||||
|
||||
func TestDelayedCommunication(t *testing.T) {
|
||||
resp, err := configureBrowserExtension()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to configure browser extension: %v", err)
|
||||
}
|
||||
|
||||
t.Run("BE sleeps before sending a message", func(t *testing.T) {
|
||||
deviceID := getDeviceID()
|
||||
testPairing(t, deviceID, resp, 21*time.Second, 0)
|
||||
})
|
||||
t.Run("Mobile sleeps before sending a message", func(t *testing.T) {
|
||||
deviceID := getDeviceID()
|
||||
testPairing(t, deviceID, resp, 0, 21*time.Second)
|
||||
})
|
||||
t.Run("Both sleep before sending a message", func(t *testing.T) {
|
||||
deviceID := getDeviceID()
|
||||
testPairing(t, deviceID, resp, 21*time.Second, 21*time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPairHappyFlow(t *testing.T) {
|
||||
resp, err := configureBrowserExtension()
|
||||
if err != nil {
|
||||
@ -44,7 +23,7 @@ func TestPairHappyFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
deviceID := getDeviceID()
|
||||
testPairing(t, deviceID, resp, 0, 0)
|
||||
testPairing(t, deviceID, resp)
|
||||
}
|
||||
|
||||
func TestPairMultipleTimes(t *testing.T) {
|
||||
@ -55,14 +34,14 @@ func TestPairMultipleTimes(t *testing.T) {
|
||||
|
||||
deviceID := getDeviceID()
|
||||
for i := 0; i < 10; i++ {
|
||||
testPairing(t, deviceID, resp, 0, 0)
|
||||
testPairing(t, deviceID, resp)
|
||||
if t.Failed() {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse, sleepBeforeSendBE, sleepBeforeSendMobile time.Duration) {
|
||||
func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse) {
|
||||
t.Helper()
|
||||
|
||||
browserExtensionDone := make(chan struct{})
|
||||
@ -84,7 +63,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe
|
||||
extProxyToken,
|
||||
msgOfSize(messageSize, 'b'),
|
||||
msgOfSize(messageSize, 'm'),
|
||||
sleepBeforeSendBE,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Browser Extension: proxy failed: %v", err)
|
||||
@ -106,7 +84,6 @@ func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionRe
|
||||
mobileProxyToken,
|
||||
msgOfSize(messageSize, 'm'),
|
||||
msgOfSize(messageSize, 'b'),
|
||||
sleepBeforeSendMobile,
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: proxy failed: %v", err)
|
||||
|
@ -2,7 +2,6 @@ package pass
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@ -53,9 +52,7 @@ func TestSyncHappyFlow(t *testing.T) {
|
||||
getWsURL()+"/browser_extension/sync/proxy",
|
||||
proxyToken,
|
||||
"sent from browser extension",
|
||||
"sent from mobile",
|
||||
time.Duration(0),
|
||||
)
|
||||
"sent from mobile")
|
||||
if err != nil {
|
||||
t.Errorf("Browser Extension: proxy failed: %v", err)
|
||||
return
|
||||
@ -87,7 +84,6 @@ func TestSyncHappyFlow(t *testing.T) {
|
||||
proxyToken,
|
||||
"sent from mobile",
|
||||
"sent from browser extension",
|
||||
time.Duration(0),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("Mobile: proxy failed: %v", err)
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@ -101,9 +100,7 @@ func dialWS(url, auth string) (*websocket.Conn, error) {
|
||||
|
||||
// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and
|
||||
// read exactly one message (and then check it is `expectedReadMsg`).
|
||||
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBeforeSend time.Duration) error {
|
||||
const wsPingFrequency = 5 * time.Second // how often server send pings
|
||||
|
||||
func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error {
|
||||
conn, err := dialWS(url, token)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -111,30 +108,20 @@ func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string, sleepBe
|
||||
defer conn.Close()
|
||||
|
||||
doneReading := make(chan error)
|
||||
doneWriting := atomic.Bool{}
|
||||
doneWriting.Store(false)
|
||||
|
||||
go func() {
|
||||
defer close(doneReading)
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
doneReading <- fmt.Errorf("failed to read message: %w", err)
|
||||
doneReading <- fmt.Errorf("faile to read message: %w", err)
|
||||
return
|
||||
}
|
||||
if string(message) != expectedReadMsg {
|
||||
doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message))
|
||||
}
|
||||
for !doneWriting.Load() {
|
||||
conn.SetReadDeadline(time.Now().Add(wsPingFrequency + time.Second))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(sleepBeforeSend)
|
||||
|
||||
defer doneWriting.Store(true)
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil {
|
||||
return fmt.Errorf("failed to write message: %w", err)
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package connection
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -12,19 +12,36 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 20 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = pongWait / 4
|
||||
DefaultWriteTimeout = 10 * time.Second
|
||||
DefaultReadTimeout = 20 * time.Second
|
||||
DefaultPingFrequency = DefaultReadTimeout / 4
|
||||
DefaultDisconnectAfter = 3 * time.Minute
|
||||
|
||||
// Maximum message size allowed from peer.
|
||||
maxMessageSize = 10 * (2 << 20)
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WriteTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
PingFrequency time.Duration
|
||||
DisconnectAfter time.Duration
|
||||
}
|
||||
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
WriteTimeout: DefaultWriteTimeout,
|
||||
ReadTimeout: DefaultReadTimeout,
|
||||
PingFrequency: DefaultPingFrequency,
|
||||
DisconnectAfter: DefaultDisconnectAfter,
|
||||
}
|
||||
}
|
||||
|
||||
type WriterCloser interface {
|
||||
Write(msg []byte)
|
||||
Close()
|
||||
}
|
||||
|
||||
var (
|
||||
newline = []byte{'\n'}
|
||||
space = []byte{' '}
|
||||
@ -37,20 +54,21 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// proxy is a responsible for reading from read chan and sending it over wsConn
|
||||
// and reading fom wsChan and sending it over send chan
|
||||
// proxy is a responsible for reading from reader chan and sending it over conn
|
||||
// and reading fom conn and sending it over writer.
|
||||
type proxy struct {
|
||||
send *safeChannel
|
||||
read chan []byte
|
||||
|
||||
conn *websocket.Conn
|
||||
writer WriterCloser
|
||||
reader chan []byte
|
||||
conn *websocket.Conn
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
|
||||
func Start(wsConn *websocket.Conn, writer WriterCloser, reader chan []byte, cfg Config) {
|
||||
proxy := &proxy{
|
||||
send: send,
|
||||
read: read,
|
||||
conn: wsConn,
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
conn: wsConn,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
@ -67,11 +85,10 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
timeout := time.After(cfg.DisconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
logging.Info("Connection closed after", cfg.DisconnectAfter)
|
||||
|
||||
proxy.conn.Close()
|
||||
})
|
||||
@ -79,7 +96,7 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// readPump pumps messages from the websocket proxy to send.
|
||||
// readPump pumps messages from the websocket proxy to writer.
|
||||
//
|
||||
// The application runs readPump in a per-proxy goroutine. The application
|
||||
// ensures that there is at most one reader on a proxy by executing all
|
||||
@ -87,13 +104,13 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
|
||||
func (p *proxy) readPump() {
|
||||
defer func() {
|
||||
p.conn.Close()
|
||||
p.send.close()
|
||||
p.writer.Close()
|
||||
}()
|
||||
|
||||
p.conn.SetReadLimit(maxMessageSize)
|
||||
p.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
|
||||
p.conn.SetPongHandler(func(string) error {
|
||||
p.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
p.conn.SetReadDeadline(time.Now().Add(p.cfg.ReadTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
@ -112,17 +129,17 @@ func (p *proxy) readPump() {
|
||||
break
|
||||
}
|
||||
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
|
||||
p.send.write(message)
|
||||
p.writer.Write(message)
|
||||
}
|
||||
}
|
||||
|
||||
// writePump pumps messages from the read chan to the websocket proxy.
|
||||
// writePump pumps messages from the reader chan to the websocket proxy.
|
||||
//
|
||||
// A goroutine running writePump is started for each proxy. The
|
||||
// application ensures that there is at most one writer to a proxy by
|
||||
// executing all writes from this goroutine.
|
||||
func (p *proxy) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
ticker := time.NewTicker(p.cfg.PingFrequency)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
p.conn.Close()
|
||||
@ -130,8 +147,8 @@ func (p *proxy) writePump() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-p.read:
|
||||
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
case message, ok := <-p.reader:
|
||||
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
|
||||
if !ok {
|
||||
// The hub closed the channel.
|
||||
p.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
@ -148,7 +165,7 @@ func (p *proxy) writePump() {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
p.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
p.conn.SetWriteDeadline(time.Now().Add(p.cfg.WriteTimeout))
|
||||
if err := p.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
@ -10,13 +10,13 @@ type proxyPool struct {
|
||||
proxies map[string]*proxyPair
|
||||
}
|
||||
|
||||
// registerMobileConn register proxyPair if not existing in pool and returns it.
|
||||
func (pp *proxyPool) getOrCreateProxyPair(id string) *proxyPair {
|
||||
// getOrCreateProxyPair registers proxyPair if not existing in pool and returns it.
|
||||
func (pp *proxyPool) getOrCreateProxyPair(id string, disconnectAfter time.Duration) *proxyPair {
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
v, ok := pp.proxies[id]
|
||||
if !ok {
|
||||
v = initProxyPair()
|
||||
v = initProxyPair(disconnectAfter)
|
||||
}
|
||||
pp.proxies[id] = v
|
||||
return v
|
||||
@ -48,12 +48,11 @@ type proxyPair struct {
|
||||
}
|
||||
|
||||
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
|
||||
func initProxyPair() *proxyPair {
|
||||
const proxyTimeout = 3 * time.Minute
|
||||
func initProxyPair(disconnectAfter time.Duration) *proxyPair {
|
||||
return &proxyPair{
|
||||
toMobileDataCh: newSafeChannel(),
|
||||
toExtensionDataCh: newSafeChannel(),
|
||||
expiresAt: time.Now().Add(proxyTimeout),
|
||||
expiresAt: time.Now().Add(disconnectAfter + time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,7 +68,7 @@ func newSafeChannel() *safeChannel {
|
||||
}
|
||||
}
|
||||
|
||||
func (sc *safeChannel) write(data []byte) {
|
||||
func (sc *safeChannel) Write(data []byte) {
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
@ -80,7 +79,7 @@ func (sc *safeChannel) write(data []byte) {
|
||||
sc.channel <- data
|
||||
}
|
||||
|
||||
func (sc *safeChannel) close() {
|
||||
func (sc *safeChannel) Close() {
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
|
@ -6,15 +6,17 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
|
||||
)
|
||||
|
||||
// ProxyServer manages proxy connections between Browser Extension and Mobile.
|
||||
type ProxyServer struct {
|
||||
proxyPool *proxyPool
|
||||
idLabel string
|
||||
proxyPool *proxyPool
|
||||
idLabel string
|
||||
proxyConfig proxy.Config
|
||||
}
|
||||
|
||||
func NewProxyServer(idLabel string) *ProxyServer {
|
||||
func NewProxyServer(idLabel string, proxyConfig proxy.Config) *ProxyServer {
|
||||
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
@ -24,8 +26,9 @@ func NewProxyServer(idLabel string) *ProxyServer {
|
||||
}
|
||||
}()
|
||||
return &ProxyServer{
|
||||
proxyPool: proxyPool,
|
||||
idLabel: idLabel,
|
||||
proxyPool: proxyPool,
|
||||
idLabel: idLabel,
|
||||
proxyConfig: proxyConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,8 +41,8 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht
|
||||
|
||||
log.Infof("Starting ServeExtensionProxyToMobileWS")
|
||||
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
|
||||
startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)
|
||||
proxy.Start(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel, p.proxyConfig)
|
||||
|
||||
p.proxyPool.deleteProxyPair(id)
|
||||
return nil
|
||||
@ -52,9 +55,9 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht
|
||||
}
|
||||
|
||||
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(id, p.proxyConfig.DisconnectAfter)
|
||||
|
||||
startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel)
|
||||
proxy.Start(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel, p.proxyConfig)
|
||||
|
||||
p.proxyPool.deleteProxyPair(id)
|
||||
return nil
|
||||
|
205
internal/pass/connection/proxy_server_test.go
Normal file
205
internal/pass/connection/proxy_server_test.go
Normal file
@ -0,0 +1,205 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logging.Init(nil)
|
||||
}
|
||||
|
||||
// TestProxy sends message both ways and makes sure it is received correctly.
|
||||
func TestProxy(t *testing.T) {
|
||||
ws1, ws2, cleanup := setupConnections(t, proxy.DefaultConfig())
|
||||
defer cleanup()
|
||||
|
||||
testWriteReceive(t, ws1, ws2)
|
||||
testWriteReceive(t, ws2, ws1)
|
||||
}
|
||||
|
||||
// TestConnectionIsClosedAfterTheSpecifiedTime checks that `DisconnectAfter` is obeyed by the proxy server.
|
||||
func TestConnectionIsClosedAfterTheSpecifiedTime(t *testing.T) {
|
||||
timeout := time.Second
|
||||
|
||||
ws1, ws2, cleanup := setupConnections(t, proxy.Config{
|
||||
WriteTimeout: proxy.DefaultWriteTimeout,
|
||||
ReadTimeout: proxy.DefaultReadTimeout,
|
||||
PingFrequency: proxy.DefaultPingFrequency,
|
||||
DisconnectAfter: timeout,
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
// Exchange some data to make sure the connection is established.
|
||||
testWriteReceive(t, ws1, ws2)
|
||||
testWriteReceive(t, ws2, ws1)
|
||||
|
||||
// Neither side of the connection sends any message, they just wait on read. Therefore, in both cases ReadMessage
|
||||
// should exit after the server closes the connection.
|
||||
ws1Result := make(chan error)
|
||||
ws2Result := make(chan error)
|
||||
go func() {
|
||||
_, _, err := ws1.ReadMessage()
|
||||
ws1Result <- err
|
||||
}()
|
||||
go func() {
|
||||
_, _, err := ws2.ReadMessage()
|
||||
ws2Result <- err
|
||||
}()
|
||||
|
||||
// Finish test after timeout and check if connections were closed. One would expect a race condition here
|
||||
// (we check exactly after timeout) but this test seems to be stable. This is because we have already spent some time
|
||||
// exchanging the data before waiting for the timeout.
|
||||
after := time.After(timeout)
|
||||
var err1, err2 error
|
||||
done := false
|
||||
for !done {
|
||||
select {
|
||||
case err1 = <-ws1Result:
|
||||
case err2 = <-ws2Result:
|
||||
case <-after:
|
||||
done = true
|
||||
}
|
||||
}
|
||||
|
||||
if err1 == nil {
|
||||
t.Logf("WebSocket 1 connection wasn't closed")
|
||||
}
|
||||
if err2 == nil {
|
||||
t.Logf("WebSocket 2 connection wasn't closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPingPongIsEnoughToKeepUsAlive check that the connection is kept alive by the ws native ping-pong mechanism.
|
||||
// In the Browser Extension the pong response is sent by the browser automatically, in this test framework does it for us
|
||||
// in ReadMessage.
|
||||
func TestPingPongIsEnoughToKeepUsAlive(t *testing.T) {
|
||||
readTimeout := time.Second
|
||||
|
||||
ws1, ws2, cleanup := setupConnections(t, proxy.Config{
|
||||
WriteTimeout: proxy.DefaultWriteTimeout,
|
||||
ReadTimeout: readTimeout,
|
||||
PingFrequency: readTimeout / 4,
|
||||
DisconnectAfter: time.Minute,
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
group := errgroup.Group{}
|
||||
group.Go(func() error {
|
||||
_, _, err := ws1.ReadMessage()
|
||||
return err
|
||||
})
|
||||
group.Go(func() error {
|
||||
_, _, err := ws2.ReadMessage()
|
||||
return err
|
||||
})
|
||||
time.Sleep(4 * readTimeout)
|
||||
|
||||
// Write some messages to both websockets. This has two benefits:
|
||||
// 1. It ensures the connections are still alive,
|
||||
// 2. It makes ReadMessage above return, so group.Wait will exit.
|
||||
if err := ws1.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil {
|
||||
t.Errorf("Failed to write message to the first websocket: %v", err)
|
||||
}
|
||||
if err := ws2.WriteMessage(websocket.BinaryMessage, []byte("hello!")); err != nil {
|
||||
t.Errorf("Failed to write message to the second websocket: %v", err)
|
||||
}
|
||||
|
||||
err := group.Wait()
|
||||
if err != nil {
|
||||
t.Errorf("Error when reading from websocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// setupConnections creates new test websocket server and two connected clients paired in a proxy.
|
||||
func setupConnections(t *testing.T, cfg proxy.Config) (*websocket.Conn, *websocket.Conn, func()) {
|
||||
s := httptest.NewServer(testHandler{
|
||||
t: t,
|
||||
ps: NewProxyServer("id", cfg),
|
||||
})
|
||||
|
||||
ws1, _, err := testDialer.Dial(makeWsURL(s.URL, "mobile", "1"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
|
||||
ws2, _, err := testDialer.Dial(makeWsURL(s.URL, "extension", "1"), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
ws1.Close()
|
||||
ws2.Close()
|
||||
s.Close()
|
||||
}
|
||||
|
||||
return ws1, ws2, cleanup
|
||||
}
|
||||
|
||||
var testDialer = websocket.Dialer{
|
||||
Subprotocols: []string{"2pass.io"},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
HandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// testHandler is for handling http connections.
|
||||
type testHandler struct {
|
||||
t *testing.T
|
||||
ps *ProxyServer
|
||||
}
|
||||
|
||||
func (t testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/mobile" {
|
||||
t.ps.ServeExtensionProxyToMobileWS(w, r, r.URL.Query().Get("id"))
|
||||
} else if r.URL.Path == "/extension" {
|
||||
t.ps.ServeMobileProxyToExtensionWS(w, r, r.URL.Query().Get("id"))
|
||||
} else {
|
||||
http.Error(w, "invalid path", http.StatusNotFound)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// makeWsURL constructs the WebSocket from the test server's URL.
|
||||
func makeWsURL(s string, app string, id string) string {
|
||||
return fmt.Sprintf("ws%s/%s?id=%s", strings.TrimPrefix(s, "http"), app, id)
|
||||
}
|
||||
|
||||
// testWriteReceive writes a message to w1 and makes sure it is received by w2.
|
||||
func testWriteReceive(t *testing.T, ws1, ws2 *websocket.Conn) {
|
||||
t.Helper()
|
||||
const message = "Hello, WebSocket!"
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
_, received, err := ws2.ReadMessage()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read message: %v", err)
|
||||
return
|
||||
}
|
||||
if string(received) != message {
|
||||
t.Errorf("Expected %q, received %q", message, string(received))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := ws1.WriteMessage(websocket.BinaryMessage, []byte(message)); err != nil {
|
||||
t.Errorf("Failed to write message: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
@ -13,6 +13,7 @@ import (
|
||||
httphelpers "github.com/twofas/2fas-server/internal/common/http"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection"
|
||||
"github.com/twofas/2fas-server/internal/pass/connection/proxy"
|
||||
"github.com/twofas/2fas-server/internal/pass/fcm"
|
||||
"github.com/twofas/2fas-server/internal/pass/pairing"
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
@ -64,10 +65,10 @@ func NewServer(cfg config.PassConfig) *Server {
|
||||
}
|
||||
|
||||
pairingApp := pairing.NewApp(signSvc, cfg.PairingRequestTokenValidityDuration)
|
||||
proxyPairingApp := connection.NewProxyServer("device_id")
|
||||
proxyPairingApp := connection.NewProxyServer("device_id", proxy.DefaultConfig())
|
||||
|
||||
syncApp := sync.NewApp(signSvc, fcmClient)
|
||||
proxySyncApp := connection.NewProxyServer("fcm_token")
|
||||
proxySyncApp := connection.NewProxyServer("fcm_token", proxy.DefaultConfig())
|
||||
|
||||
router := gin.New()
|
||||
router.Use(recovery.RecoveryMiddleware())
|
||||
|
Loading…
Reference in New Issue
Block a user