mirror of
https://github.com/twofas/2fas-server.git
synced 2024-12-04 16:20:13 +01:00
fix: race condition in hub
This commit is contained in:
parent
7f7ea0693a
commit
2520156c3f
@ -2,9 +2,11 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"github.com/twofas/2fas-server/internal/common/logging"
|
"github.com/twofas/2fas-server/internal/common/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,6 +45,8 @@ type Client struct {
|
|||||||
|
|
||||||
// Buffered channel of outbound messages.
|
// Buffered channel of outbound messages.
|
||||||
send chan []byte
|
send chan []byte
|
||||||
|
|
||||||
|
sendMtx *sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// readPump pumps messages from the websocket connection to the hub.
|
// readPump pumps messages from the websocket connection to the hub.
|
||||||
@ -133,3 +137,27 @@ func (c *Client) writePump() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendMsg(bb []byte) bool {
|
||||||
|
c.sendMtx.Lock()
|
||||||
|
defer c.sendMtx.Unlock()
|
||||||
|
|
||||||
|
if c.send == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.send <- bb
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) close() {
|
||||||
|
c.sendMtx.Lock()
|
||||||
|
defer c.sendMtx.Unlock()
|
||||||
|
|
||||||
|
if c.send == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(c.send)
|
||||||
|
c.send = nil
|
||||||
|
}
|
||||||
|
@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
close(c.send)
|
c.close()
|
||||||
if h.isEmpty() {
|
if h.isEmpty() {
|
||||||
h.onHubHasNoClients(h.id)
|
h.onHubHasNoClients(h.id)
|
||||||
}
|
}
|
||||||
@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
select {
|
ok = c.sendMsg(msg)
|
||||||
case c.send <- msg:
|
if !ok {
|
||||||
default:
|
|
||||||
h.unregisterClient(c)
|
h.unregisterClient(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client,
|
|||||||
defer h.mtx.Unlock()
|
defer h.mtx.Unlock()
|
||||||
|
|
||||||
hub := h.getOrCreateHub(channel)
|
hub := h.getOrCreateHub(channel)
|
||||||
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
|
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}}
|
||||||
hub.registerClient(client)
|
hub.registerClient(client)
|
||||||
|
|
||||||
// handler (caller of this method) isn't really interested in hub,
|
// handler (caller of this method) isn't really interested in hub,
|
||||||
|
@ -50,6 +50,7 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
|||||||
hp := newHubPool()
|
hp := newHubPool()
|
||||||
const channelsNo = 100
|
const channelsNo = 100
|
||||||
const clientsPerChannel = 1000
|
const clientsPerChannel = 1000
|
||||||
|
const messagesSentToEachHub = 100
|
||||||
|
|
||||||
hubs := &sync.Map{}
|
hubs := &sync.Map{}
|
||||||
|
|
||||||
@ -58,21 +59,33 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
|||||||
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
|
// This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines.
|
||||||
// Each of them will call `wg.Done() once and we can't progress until all of them are done.
|
// Each of them will call `wg.Done() once and we can't progress until all of them are done.
|
||||||
wg.Add(channelsNo*clientsPerChannel + channelsNo)
|
wg.Add(channelsNo*clientsPerChannel + channelsNo)
|
||||||
|
// We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and
|
||||||
|
// wait for it to finish.
|
||||||
|
wg.Add(channelsNo * clientsPerChannel)
|
||||||
|
|
||||||
for i := 0; i < channelsNo; i++ {
|
for i := 0; i < channelsNo; i++ {
|
||||||
channelID := fmt.Sprintf("channel-%d", i)
|
channelID := fmt.Sprintf("channel-%d", i)
|
||||||
|
|
||||||
|
c, h := hp.registerClient(channelID, &websocket.Conn{})
|
||||||
|
hubs.Store(h, struct{}{})
|
||||||
|
go fakeReadPump(c.send, &wg)
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < messagesSentToEachHub; i++ {
|
||||||
|
h.broadcastMsg([]byte("test"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for j := 0; j < clientsPerChannel; j++ {
|
for j := 0; j < clientsPerChannel; j++ {
|
||||||
c, h := hp.registerClient(channelID, &websocket.Conn{})
|
c, h := hp.registerClient(channelID, &websocket.Conn{})
|
||||||
hubs.Store(h, struct{}{})
|
go fakeReadPump(c.send, &wg)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
h.unregisterClient(c)
|
h.unregisterClient(c)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
_, h := hp.registerClient(channelID, &websocket.Conn{})
|
|
||||||
hubs.Store(h, struct{}{})
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@ -93,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) {
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
for range c {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user