fix: race condition in hub

This commit is contained in:
Krzysztof Dryś 2024-04-08 14:02:16 +02:00
parent 7f7ea0693a
commit 2520156c3f
4 changed files with 54 additions and 8 deletions

View File

@ -2,9 +2,11 @@ package common
import ( import (
"bytes" "bytes"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
) )
@ -43,6 +45,8 @@ type Client struct {
// Buffered channel of outbound messages. // Buffered channel of outbound messages.
send chan []byte send chan []byte
sendMtx *sync.Mutex
} }
// readPump pumps messages from the websocket connection to the hub. // readPump pumps messages from the websocket connection to the hub.
@ -133,3 +137,27 @@ func (c *Client) writePump() {
} }
} }
} }
func (c *Client) sendMsg(bb []byte) bool {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()
if c.send == nil {
return false
}
c.send <- bb
return true
}
func (c *Client) close() {
c.sendMtx.Lock()
defer c.sendMtx.Unlock()
if c.send == nil {
return
}
close(c.send)
c.send = nil
}

View File

@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) {
if !ok { if !ok {
return return
} }
close(c.send) c.close()
if h.isEmpty() { if h.isEmpty() {
h.onHubHasNoClients(h.id) h.onHubHasNoClients(h.id)
} }
@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) {
if !ok { if !ok {
return return
} }
select { ok = c.sendMsg(msg)
case c.send <- msg: if !ok {
default:
h.unregisterClient(c) h.unregisterClient(c)
} }
} }

View File

@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client,
defer h.mtx.Unlock() defer h.mtx.Unlock()
hub := h.getOrCreateHub(channel) hub := h.getOrCreateHub(channel)
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}}
hub.registerClient(client) hub.registerClient(client)
// handler (caller of this method) isn't really interested in hub, // handler (caller of this method) isn't really interested in hub,

View File

@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) {
hp := newHubPool() hp := newHubPool()
const channelsNo = 100 const channelsNo = 100
const clientsPerChannel = 1000 const clientsPerChannel = 1000
const messagesSentToEachHub = 100
hubs := &sync.Map{} hubs := &sync.Map{}
@ -58,21 +59,33 @@ func TestCreateRemoveConcurrently(t *testing.T) {
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines. // This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
// Each of them will call `wg.Done() once and we can't progress until all of them are done. // Each of them will call `wg.Done() once and we can't progress until all of them are done.
wg.Add(channelsNo*clientsPerChannel + channelsNo) wg.Add(channelsNo*clientsPerChannel + channelsNo)
// We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and
// wait for it to finish.
wg.Add(channelsNo * clientsPerChannel)
for i := 0; i < channelsNo; i++ { for i := 0; i < channelsNo; i++ {
channelID := fmt.Sprintf("channel-%d", i) channelID := fmt.Sprintf("channel-%d", i)
c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
go fakeReadPump(c.send, &wg)
go func() {
for i := 0; i < messagesSentToEachHub; i++ {
h.broadcastMsg([]byte("test"))
}
}()
go func() { go func() {
defer wg.Done() defer wg.Done()
for j := 0; j < clientsPerChannel; j++ { for j := 0; j < clientsPerChannel; j++ {
c, h := hp.registerClient(channelID, &websocket.Conn{}) c, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{}) go fakeReadPump(c.send, &wg)
go func() { go func() {
h.unregisterClient(c) h.unregisterClient(c)
wg.Done() wg.Done()
}() }()
} }
_, h := hp.registerClient(channelID, &websocket.Conn{})
hubs.Store(h, struct{}{})
}() }()
} }
wg.Wait() wg.Wait()
@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) {
return true return true
}) })
} }
func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
defer wg.Done()
for range c {
}
}