mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-07 06:55:49 +01:00
fix: race condition in hub
This commit is contained in:
parent
7f7ea0693a
commit
2520156c3f
@ -2,9 +2,11 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
@ -43,6 +45,8 @@ type Client struct {
|
||||
|
||||
// Buffered channel of outbound messages.
|
||||
send chan []byte
|
||||
|
||||
sendMtx *sync.Mutex
|
||||
}
|
||||
|
||||
// readPump pumps messages from the websocket connection to the hub.
|
||||
@ -133,3 +137,27 @@ func (c *Client) writePump() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) sendMsg(bb []byte) bool {
|
||||
c.sendMtx.Lock()
|
||||
defer c.sendMtx.Unlock()
|
||||
|
||||
if c.send == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
c.send <- bb
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) close() {
|
||||
c.sendMtx.Lock()
|
||||
defer c.sendMtx.Unlock()
|
||||
|
||||
if c.send == nil {
|
||||
return
|
||||
}
|
||||
|
||||
close(c.send)
|
||||
c.send = nil
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
close(c.send)
|
||||
c.close()
|
||||
if h.isEmpty() {
|
||||
h.onHubHasNoClients(h.id)
|
||||
}
|
||||
@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case c.send <- msg:
|
||||
default:
|
||||
ok = c.sendMsg(msg)
|
||||
if !ok {
|
||||
h.unregisterClient(c)
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client,
|
||||
defer h.mtx.Unlock()
|
||||
|
||||
hub := h.getOrCreateHub(channel)
|
||||
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
|
||||
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}}
|
||||
hub.registerClient(client)
|
||||
|
||||
// handler (caller of this method) isn't really interested in hub,
|
||||
|
@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
||||
hp := newHubPool()
|
||||
const channelsNo = 100
|
||||
const clientsPerChannel = 1000
|
||||
const messagesSentToEachHub = 100
|
||||
|
||||
hubs := &sync.Map{}
|
||||
|
||||
@ -58,21 +59,33 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
||||
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
|
||||
// Each of them will call `wg.Done() once and we can't progress until all of them are done.
|
||||
wg.Add(channelsNo*clientsPerChannel + channelsNo)
|
||||
// We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and
|
||||
// wait for it to finish.
|
||||
wg.Add(channelsNo * clientsPerChannel)
|
||||
|
||||
for i := 0; i < channelsNo; i++ {
|
||||
channelID := fmt.Sprintf("channel-%d", i)
|
||||
|
||||
c, h := hp.registerClient(channelID, &websocket.Conn{})
|
||||
hubs.Store(h, struct{}{})
|
||||
go fakeReadPump(c.send, &wg)
|
||||
go func() {
|
||||
for i := 0; i < messagesSentToEachHub; i++ {
|
||||
h.broadcastMsg([]byte("test"))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < clientsPerChannel; j++ {
|
||||
c, h := hp.registerClient(channelID, &websocket.Conn{})
|
||||
hubs.Store(h, struct{}{})
|
||||
go fakeReadPump(c.send, &wg)
|
||||
|
||||
go func() {
|
||||
h.unregisterClient(c)
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
_, h := hp.registerClient(channelID, &websocket.Conn{})
|
||||
hubs.Store(h, struct{}{})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for range c {
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user