fix(pass): increase max message size

Also: fix race condition when closing connection
This commit is contained in:
Krzysztof Dryś 2024-04-09 09:12:32 +02:00
parent 34d87a852a
commit 905606b135
4 changed files with 63 additions and 15 deletions

View File

@ -21,7 +21,7 @@ const (
pingPeriod = (pongWait * 9) / 10 pingPeriod = (pongWait * 9) / 10
// Maximum message size allowed from peer. // Maximum message size allowed from peer.
maxMessageSize = 4 * 1048 maxMessageSize = 10 * (2 << 20)
) )
var ( var (
@ -39,13 +39,13 @@ var (
// proxy is a responsible for reading from read chan and sending it over wsConn // proxy is a responsible for reading from read chan and sending it over wsConn
// and reading fom wsChan and sending it over send chan // and reading fom wsChan and sending it over send chan
type proxy struct { type proxy struct {
send chan []byte send *safeChannel
read chan []byte read chan []byte
conn *websocket.Conn conn *websocket.Conn
} }
func startProxy(wsConn *websocket.Conn, send, read chan []byte) { func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) {
proxy := &proxy{ proxy := &proxy{
send: send, send: send,
read: read, read: read,
@ -79,7 +79,7 @@ func startProxy(wsConn *websocket.Conn, send, read chan []byte) {
func (p *proxy) readPump() { func (p *proxy) readPump() {
defer func() { defer func() {
p.conn.Close() p.conn.Close()
close(p.send) p.send.close()
}() }()
p.conn.SetReadLimit(maxMessageSize) p.conn.SetReadLimit(maxMessageSize)
@ -104,7 +104,7 @@ func (p *proxy) readPump() {
break break
} }
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
p.send <- message p.send.write(message)
} }
} }

View File

@ -34,8 +34,8 @@ func (pp *proxyPool) deleteExpiresPairs() {
} }
type proxyPair struct { type proxyPair struct {
toMobileDataCh chan []byte toMobileDataCh *safeChannel
toExtensionDataCh chan []byte toExtensionDataCh *safeChannel
expiresAt time.Time expiresAt time.Time
} }
@ -43,8 +43,43 @@ type proxyPair struct {
func initProxyPair() *proxyPair { func initProxyPair() *proxyPair {
const proxyTimeout = 3 * time.Minute const proxyTimeout = 3 * time.Minute
return &proxyPair{ return &proxyPair{
toMobileDataCh: make(chan []byte), toMobileDataCh: newSafeChannel(),
toExtensionDataCh: make(chan []byte), toExtensionDataCh: newSafeChannel(),
expiresAt: time.Now().Add(proxyTimeout), expiresAt: time.Now().Add(proxyTimeout),
} }
} }
type safeChannel struct {
channel chan []byte
mu *sync.Mutex
}
func newSafeChannel() *safeChannel {
return &safeChannel{
channel: make(chan []byte),
mu: &sync.Mutex{},
}
}
func (sc *safeChannel) write(data []byte) {
sc.mu.Lock()
defer sc.mu.Unlock()
if sc.channel == nil {
return
}
sc.channel <- data
}
func (sc *safeChannel) close() {
sc.mu.Lock()
defer sc.mu.Unlock()
if sc.channel == nil {
return
}
close(sc.channel)
sc.channel = nil
}

View File

@ -39,7 +39,7 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht
log.Infof("Starting ServeExtensionProxyToMobileWS") log.Infof("Starting ServeExtensionProxyToMobileWS")
proxyPair := p.proxyPool.getOrCreateProxyPair(id) proxyPair := p.proxyPool.getOrCreateProxyPair(id)
startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh) startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel)
return nil return nil
} }
@ -52,7 +52,7 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id) logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", id)
proxyPair := p.proxyPool.getOrCreateProxyPair(id) proxyPair := p.proxyPool.getOrCreateProxyPair(id)
startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh) startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel)
return nil return nil
} }

View File

@ -6,6 +6,16 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
func msgOfSize(size int, c byte) string {
msg := make([]byte, size)
for i := range msg {
msg[i] = c
}
return string(msg)
}
func TestPairHappyFlow(t *testing.T) { func TestPairHappyFlow(t *testing.T) {
resp, err := configureBrowserExtension() resp, err := configureBrowserExtension()
if err != nil { if err != nil {
@ -15,6 +25,8 @@ func TestPairHappyFlow(t *testing.T) {
browserExtensionDone := make(chan struct{}) browserExtensionDone := make(chan struct{})
mobileDone := make(chan struct{}) mobileDone := make(chan struct{})
const messageSize = 1024 * 1024
go func() { go func() {
defer close(browserExtensionDone) defer close(browserExtensionDone)
@ -27,8 +39,9 @@ func TestPairHappyFlow(t *testing.T) {
err = proxyWebSocket( err = proxyWebSocket(
getWsURL()+"/browser_extension/proxy_to_mobile", getWsURL()+"/browser_extension/proxy_to_mobile",
extProxyToken, extProxyToken,
"sent from browser extension", msgOfSize(messageSize, 'b'),
"sent from mobile") msgOfSize(messageSize, 'm'),
)
if err != nil { if err != nil {
t.Errorf("Browser Extension: proxy failed: %v", err) t.Errorf("Browser Extension: proxy failed: %v", err)
return return
@ -47,8 +60,8 @@ func TestPairHappyFlow(t *testing.T) {
err = proxyWebSocket( err = proxyWebSocket(
getWsURL()+"/mobile/proxy_to_browser_extension", getWsURL()+"/mobile/proxy_to_browser_extension",
mobileProxyToken, mobileProxyToken,
"sent from mobile", msgOfSize(messageSize, 'm'),
"sent from browser extension", msgOfSize(messageSize, 'b'),
) )
if err != nil { if err != nil {
t.Errorf("Mobile: proxy failed: %v", err) t.Errorf("Mobile: proxy failed: %v", err)