From e75800c5e39ee6576dd733d62bad1171b4726227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Dry=C5=9B?= Date: Tue, 9 Apr 2024 09:08:48 +0200 Subject: [PATCH] fix: race condition in hub (#39) --- internal/websocket/common/client.go | 28 ++++++++++++++++++++++ internal/websocket/common/hub.go | 7 +++--- internal/websocket/common/hub_pool.go | 2 +- internal/websocket/common/hub_pool_test.go | 25 ++++++++++++++++--- 4 files changed, 54 insertions(+), 8 deletions(-) diff --git a/internal/websocket/common/client.go b/internal/websocket/common/client.go index bf35fd7..db1ffe2 100644 --- a/internal/websocket/common/client.go +++ b/internal/websocket/common/client.go @@ -2,9 +2,11 @@ package common import ( "bytes" + "sync" "time" "github.com/gorilla/websocket" + "github.com/twofas/2fas-server/internal/common/logging" ) @@ -43,6 +45,8 @@ type Client struct { // Buffered channel of outbound messages. send chan []byte + + sendMtx *sync.Mutex } // readPump pumps messages from the websocket connection to the hub. 
@@ -133,3 +137,37 @@ func (c *Client) writePump() { } } } + +// sendMsg queues bb on the client's outbound channel without blocking. It +// returns false when the channel has already been closed by close or when its +// buffer is full, so a stalled client can be dropped by the caller instead of +// blocking it while sendMtx is held. +func (c *Client) sendMsg(bb []byte) bool { + c.sendMtx.Lock() + defer c.sendMtx.Unlock() + + if c.send == nil { + return false + } + + select { + case c.send <- bb: + return true + default: + return false + } +} + +// close closes the outbound channel and sets it to nil while holding sendMtx, +// so later sendMsg calls return false and repeated close calls do nothing. +func (c *Client) close() { + c.sendMtx.Lock() + defer c.sendMtx.Unlock() + + if c.send == nil { + return + } + + close(c.send) + c.send = nil +} diff --git a/internal/websocket/common/hub.go b/internal/websocket/common/hub.go index fb4ec31..54eb01a 100644 --- a/internal/websocket/common/hub.go +++ b/internal/websocket/common/hub.go @@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) { if !ok { return } - close(c.send) + c.close() if h.isEmpty() { h.onHubHasNoClients(h.id) } @@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) { if !ok { return } - select { - case c.send <- msg: - default: + ok = c.sendMsg(msg) + if !ok { h.unregisterClient(c) } } diff --git a/internal/websocket/common/hub_pool.go b/internal/websocket/common/hub_pool.go index 10b6958..70b5244 100644 --- a/internal/websocket/common/hub_pool.go +++ b/internal/websocket/common/hub_pool.go @@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client, defer h.mtx.Unlock() hub := h.getOrCreateHub(channel) - client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}} hub.registerClient(client) // handler (caller of this method) isn't really interested in hub, diff --git a/internal/websocket/common/hub_pool_test.go b/internal/websocket/common/hub_pool_test.go index b340a5e..5a0154b 100644 --- a/internal/websocket/common/hub_pool_test.go +++ b/internal/websocket/common/hub_pool_test.go @@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) { hp := newHubPool() const channelsNo = 100 const clientsPerChannel = 1000 + const messagesSentToEachHub = 100 hubs := &sync.Map{} @@ -58,21 +59,33 @@ func 
TestCreateRemoveConcurrently(t *testing.T) { // This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines. // Each of them will call `wg.Done() once and we can't progress until all of them are done. wg.Add(channelsNo*clientsPerChannel + channelsNo) + // Only the `channelsNo*clientsPerChannel` clients unregistered below are expected to be closed. Each of them gets a + // fakeReadPump, which calls wg.Done() when its send channel is closed. NOTE(review): the client registered once per channel also gets a fakeReadPump but is not included in this count; presumably it is never unregistered, confirm. + wg.Add(channelsNo * clientsPerChannel) for i := 0; i < channelsNo; i++ { channelID := fmt.Sprintf("channel-%d", i) + + c, h := hp.registerClient(channelID, &websocket.Conn{}) + hubs.Store(h, struct{}{}) + go fakeReadPump(c.send, &wg) + go func() { + for i := 0; i < messagesSentToEachHub; i++ { + h.broadcastMsg([]byte("test")) + } + }() + go func() { defer wg.Done() for j := 0; j < clientsPerChannel; j++ { c, h := hp.registerClient(channelID, &websocket.Conn{}) - hubs.Store(h, struct{}{}) + go fakeReadPump(c.send, &wg) + go func() { h.unregisterClient(c) wg.Done() }() } - _, h := hp.registerClient(channelID, &websocket.Conn{}) - hubs.Store(h, struct{}{}) }() } wg.Wait() @@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) { return true }) } + +func fakeReadPump(c chan []byte, wg *sync.WaitGroup) { + defer wg.Done() + for range c { + } +}