diff --git a/internal/pass/connection/proxy.go b/internal/pass/connection/proxy.go index 9e3ac7f..2f1567a 100644 --- a/internal/pass/connection/proxy.go +++ b/internal/pass/connection/proxy.go @@ -2,6 +2,7 @@ package connection import ( "bytes" + "sync" "time" "github.com/gorilla/websocket" @@ -52,11 +53,16 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { conn: wsConn, } + wg := sync.WaitGroup{} + wg.Add(2) + go recovery.DoNotPanic(func() { + defer wg.Done() proxy.writePump() }) go recovery.DoNotPanic(func() { + defer wg.Done() proxy.readPump() }) @@ -69,6 +75,8 @@ func startProxy(wsConn *websocket.Conn, send *safeChannel, read chan []byte) { proxy.conn.Close() }) + + wg.Wait() } // readPump pumps messages from the websocket proxy to send. diff --git a/internal/pass/connection/proxy_pool.go b/internal/pass/connection/proxy_pool.go index 603b3ba..7fa6cfc 100644 --- a/internal/pass/connection/proxy_pool.go +++ b/internal/pass/connection/proxy_pool.go @@ -33,6 +33,14 @@ func (pp *proxyPool) deleteExpiresPairs() { } } +func (pp *proxyPool) deleteProxyPair(id string) { + pp.mu.Lock() + defer pp.mu.Unlock() + + // Channels inside proxyPair are closed in proxy.readPump and proxy.writePump. + delete(pp.proxies, id) +} + type proxyPair struct { toMobileDataCh *safeChannel toExtensionDataCh *safeChannel diff --git a/internal/pass/connection/proxy_server.go b/internal/pass/connection/proxy_server.go index f6d74db..99c3dd4 100644 --- a/internal/pass/connection/proxy_server.go +++ b/internal/pass/connection/proxy_server.go @@ -40,6 +40,8 @@ func (p *ProxyServer) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *ht proxyPair := p.proxyPool.getOrCreateProxyPair(id) startProxy(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh.channel) + + p.proxyPool.deleteProxyPair(id) return nil } @@ -54,5 +56,6 @@ func (p *ProxyServer) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *ht startProxy(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh.channel) + p.proxyPool.deleteProxyPair(id) return nil } diff --git a/tests/pass/http.go b/tests/pass/http.go index ee0e896..2668b40 100644 --- a/tests/pass/http.go +++ b/tests/pass/http.go @@ -57,12 +57,7 @@ func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) { } // confirmMobile confirms pairing and returns mobile proxy token. -func confirmMobile(connectionToken, fcm string) (string, error) { - deviceID := uuid.NewString() - if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" { - deviceID = deviceIDFromEnv - } - +func confirmMobile(connectionToken, deviceID, fcm string) (string, error) { req := struct { DeviceID string `json:"device_id"` FCMToken string `json:"fcm_token"` diff --git a/tests/pass/lib.go b/tests/pass/lib.go new file mode 100644 index 0000000..ae53ddb --- /dev/null +++ b/tests/pass/lib.go @@ -0,0 +1,15 @@ +package pass + +import ( + "os" + + "github.com/google/uuid" +) + +func getDeviceID() string { + deviceID := uuid.NewString() + if deviceIDFromEnv := os.Getenv("TEST_DEVICE_ID"); deviceIDFromEnv != "" { + deviceID = deviceIDFromEnv + } + return deviceID +} diff --git a/tests/pass/pair_test.go b/tests/pass/pair_test.go index 247e151..01e228f 100644 --- a/tests/pass/pair_test.go +++ b/tests/pass/pair_test.go @@ -22,6 +22,28 @@ func TestPairHappyFlow(t *testing.T) { t.Fatalf("Failed to configure browser extension: %v", err) } + deviceID := getDeviceID() + testPairing(t, deviceID, resp) +} + +func TestPairMultipleTimes(t *testing.T) { + resp, err := configureBrowserExtension() + if err != nil { + t.Fatalf("Failed to configure browser extension: %v", err) + } + + deviceID := getDeviceID() + for i := 0; i < 10; i++ { + testPairing(t, deviceID, resp) + if t.Failed() { + break + } + } +} + +func testPairing(t *testing.T, deviceID string, resp ConfigureBrowserExtensionResponse) { + t.Helper() + browserExtensionDone := make(chan struct{}) mobileDone := make(chan struct{}) @@ -51,7 +73,7 @@ func TestPairHappyFlow(t *testing.T) { go func() { defer close(mobileDone) - mobileProxyToken, err := confirmMobile(resp.ConnectionToken, uuid.NewString()) + mobileProxyToken, err := confirmMobile(resp.ConnectionToken, deviceID, uuid.NewString()) if err != nil { t.Errorf("Mobile: confirm failed: %v", err) return diff --git a/tests/pass/sync_test.go b/tests/pass/sync_test.go index 0a16ef1..9a4e34a 100644 --- a/tests/pass/sync_test.go +++ b/tests/pass/sync_test.go @@ -16,6 +16,7 @@ func TestSyncHappyFlow(t *testing.T) { mobileParingDone := make(chan struct{}) fcm := uuid.NewString() + deviceID := getDeviceID() go func() { defer close(browserExtensionDone) @@ -44,7 +45,7 @@ func TestSyncHappyFlow(t *testing.T) { go func() { defer close(mobileParingDone) - _, err := confirmMobile(resp.ConnectionToken, fcm) + _, err := confirmMobile(resp.ConnectionToken, deviceID, fcm) if err != nil { t.Errorf("Mobile: confirm failed: %v", err) return