From 905606b1357579a090a6af9c1a52ad40dbf18101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Dry=C5=9B?= Date: Tue, 9 Apr 2024 09:12:32 +0200 Subject: [PATCH] fix(pass): increase max message size Also: fix race condition when closing connection --- internal/pass/connection/proxy.go | 10 +++--- internal/pass/connection/proxy_pool.go | 43 +++++++++++++++++++++--- internal/pass/connection/proxy_server.go | 4 +-- tests/pass/pair_test.go | 21 +++++++++--- 4 files changed, 63 insertions(+), 15 deletions(-) diff --git a/internal/pass/connection/proxy.go b/internal/pass/connection/proxy.go index 2ca853f..9e3ac7f 100644 --- a/internal/pass/connection/proxy.go +++ b/internal/pass/connection/proxy.go @@ -21,7 +21,7 @@ const ( pingPeriod = (pongWait * 9) / 10 // Maximum message size allowed from peer. - maxMessageSize = 4 * 1048 + maxMessageSize = 10 * (2 << 20) ) var ( @@ -39,13 +39,13 @@ var ( // proxy is a responsible for reading from read chan and sending it over wsConn // and reading fom wsChan and sending it over send chan type proxy struct { - send chan []byte + send *safeChannel read chan []byte conn *websocket.Conn } -func startProxy(wsConn *websocket.Conn, send, read chan []byte) { +func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { proxy := &proxy{ send: send, read: read, @@ -79,7 +79,7 @@ func startProxy(wsConn *websocket.Conn, send, read chan []byte) { func (p *proxy) readPump() { defer func() { p.conn.Close() - close(p.send) + p.send.close() }() p.conn.SetReadLimit(maxMessageSize) @@ -104,7 +104,7 @@ func (p *proxy) readPump() { break } message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) - p.send <- message + p.send.write(message) } } diff --git a/internal/pass/connection/proxy_pool.go b/internal/pass/connection/proxy_pool.go index 1a1c376..603b3ba 100644 --- a/internal/pass/connection/proxy_pool.go +++ b/internal/pass/connection/proxy_pool.go @@ -34,8 +34,8 @@ func (pp *proxyPool) deleteExpiresPairs() { } type proxyPair struct { - toMobileDataCh chan []byte - toExtensionDataCh chan []byte + toMobileDataCh *safeChannel + toExtensionDataCh *safeChannel expiresAt time.Time } @@ -43,8 +43,43 @@ type proxyPair struct { func initProxyPair() *proxyPair { const proxyTimeout = 3 * time.Minute return &proxyPair{ - toMobileDataCh: make(chan []byte), - toExtensionDataCh: make(chan []byte), + toMobileDataCh: newSafeChannel(), + toExtensionDataCh: newSafeChannel(), expiresAt: time.Now().Add(proxyTimeout), } } + +type safeChannel struct { + channel chan []byte + mu *sync.Mutex +} + +func newSafeChannel() *safeChannel { + return &safeChannel{ + channel: make(chan []byte), + mu: &sync.Mutex{}, + } +} + +func (sc *safeChannel) write(data []byte) { + sc.mu.Lock() + defer sc.mu.Unlock() + + if sc.channel == nil { + return + } + + sc.channel <- data +} + +func (sc *safeChannel) close() { + sc.mu.Lock() + defer sc.mu.Unlock() + + if sc.channel == nil { + return + } + + close(sc.channel) + sc.channel = nil +} diff --git a/internal/pass/connection/proxy_server.go b/internal/pass/connection/proxy_server.go index 3a6ec39..f6d74db 100644 --- a/internal/pass/connection/proxy_server.go +++ b/internal/pass/connection/proxy_server.go @@ -39,7 +39,7 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht log.Infof("Starting ServeExtensionProxyToMobileWS") proxyPair := p.proxyPool.getOrCreateProxyPair(id) - startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh) + startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel) return nil } @@ -52,7 +52,7 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id) proxyPair := p.proxyPool.getOrCreateProxyPair(id) - startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh) + startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel) return nil } diff --git a/tests/pass/pair_test.go b/tests/pass/pair_test.go index 67867a1..247e151 100644 --- a/tests/pass/pair_test.go +++ b/tests/pass/pair_test.go @@ -6,6 +6,16 @@ import ( "github.com/google/uuid" ) +func msgOfSize(size int, c byte) string { + msg := make([]byte, size) + + for i := range msg { + msg[i] = c + } + + return string(msg) +} + func TestPairHappyFlow(t *testing.T) { resp, err := configureBrowserExtension() if err != nil { @@ -15,6 +25,8 @@ func TestPairHappyFlow(t *testing.T) { browserExtensionDone := make(chan struct{}) mobileDone := make(chan struct{}) + const messageSize = 1024 * 1024 + go func() { defer close(browserExtensionDone) @@ -27,8 +39,9 @@ func TestPairHappyFlow(t *testing.T) { err = proxyWebSocket( getWsURL()+"/browser_extension/proxy_to_mobile", extProxyToken, - "sent from browser extension", - "sent from mobile") + msgOfSize(messageSize, 'b'), + msgOfSize(messageSize, 'm'), + ) if err != nil { t.Errorf("Browser Extension: proxy failed: %v", err) return @@ -47,8 +60,8 @@ func TestPairHappyFlow(t *testing.T) { err = proxyWebSocket( getWsURL()+"/mobile/proxy_to_browser_extension", mobileProxyToken, - "sent from mobile", - "sent from browser extension", + msgOfSize(messageSize, 'm'), + msgOfSize(messageSize, 'b'), ) if err != nil { t.Errorf("Mobile: proxy failed: %v", err)