fix: race condition in hub

This commit is contained in:
Krzysztof Dryś 2024-04-08 14:02:16 +02:00
parent 7f7ea0693a
commit 2520156c3f
4 changed files with 54 additions and 8 deletions

View File

@ -2,9 +2,11 @@ package common
import (
"bytes"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/twofas/2fas-server/internal/common/logging"
)
@ -43,6 +45,8 @@ type Client struct {
// Buffered channel of outbound messages.
send chan []byte
sendMtx *sync.Mutex
}
// readPump pumps messages from the websocket connection to the hub.
@ -133,3 +137,27 @@ func (c *Client) writePump() {
}
}
}
func (c *Client) sendMsg(bb []byte) bool {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()
if c.send == nil {
return false
}
c.send <- bb
return true
}
func (c *Client) close() {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()
if c.send == nil {
return
}
close(c.send)
c.send = nil
}

View File

@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) {
if !ok {
return
}
close(c.send)
c.close()
if h.isEmpty() {
h.onHubHasNoClients(h.id)
}
@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) {
if !ok {
return
}
select {
case c.send <- msg:
default:
ok = c.sendMsg(msg)
if !ok {
h.unregisterClient(c)
}
}

View File

@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client,
defer h.mtx.Unlock()
hub := h.getOrCreateHub(channel)
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}}
hub.registerClient(client)
// handler (caller of this method) isn't really interested in hub,

View File

@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) {
hp := newHubPool()
const channelsNo = 100
const clientsPerChannel = 1000
const messagesSentToEachHub = 100
hubs := &sync.Map{}
@ -58,21 +59,33 @@ func TestCreateRemoveConcurrently(t *testing.T) {
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
// Each of them will call `wg.Done() once and we can't progress until all of them are done.
wg.Add(channelsNo*clientsPerChannel + channelsNo)
// We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and
// wait for it to finish.
wg.Add(channelsNo * clientsPerChannel)
for i := 0; i < channelsNo; i++ {
channelID := fmt.Sprintf("channel-%d", i)
c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
go fakeReadPump(c.send, &wg)
go func() {
for i := 0; i < messagesSentToEachHub; i++ {
h.broadcastMsg([]byte("test"))
}
}()
go func() {
defer wg.Done()
for j := 0; j < clientsPerChannel; j++ {
c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
go fakeReadPump(c.send, &wg)
go func() {
h.unregisterClient(c)
wg.Done()
}()
}
_, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
}()
}
wg.Wait()
@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) {
return true
})
}
func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
defer wg.Done()
for range c {
}
}