2fas-server/internal/websocket/common/hub_pool_test.go
2024-04-09 09:08:48 +02:00

115 lines
2.8 KiB
Go

package common
import (
"fmt"
"sync"
"testing"
"github.com/gorilla/websocket"
)
// TestRegisterClientDoesNotCreateSameHubTwice registers two connections under
// one channel ID and verifies that both registrations return the same hub.
func TestRegisterClientDoesNotCreateSameHubTwice(t *testing.T) {
	const channelID = "channelID"

	pool := newHubPool()

	_, first := pool.registerClient(channelID, &websocket.Conn{})
	_, second := pool.registerClient(channelID, &websocket.Conn{})

	// Both registrations used the same ID, so the hubs must be identical.
	if first == second {
		return
	}
	t.Fatal("New hub was created")
}
// TestRemovingEmptyHub registers a client, unregisters it again and then
// registers a second client under the same channel ID. It verifies that:
//   - the first hub reports itself as empty,
//   - the second registration returned a different hub than the first one,
//   - the second hub is not empty.
func TestRemovingEmptyHub(t *testing.T) {
	hp := newHubPool()

	const channelID = "channelID"

	c, h1 := hp.registerClient(channelID, &websocket.Conn{})
	h1.unregisterClient(c)

	// The old hub lost its only client, so this registration must not reuse it.
	_, h2 := hp.registerClient(channelID, &websocket.Conn{})

	if !h1.isEmpty() {
		t.Fatal("Hub does not report it is empty, even though it should")
	}
	if h1 == h2 {
		t.Fatal("Old hub wasn't deleted")
	}
	if h2.isEmpty() {
		t.Fatal("New hub is empty, even though it shouldn't")
	}
}
// TestCreateRemoveConcurrently exercises the hub pool with concurrent
// registrations, unregistrations and broadcasts.
//
// For each channel one client is registered up front and is never unregistered
// by this test. Then `clientsPerChannel` further clients are registered on the
// same channel, and each of them is unregistered right away from its own
// goroutine.
//
// Once all of that has finished we check that:
//   - the pool holds no empty hubs,
//   - every recorded hub that is still non-empty is the hub the pool holds for
//     its ID, i.e. no non-empty hub was evicted from the pool.
func TestCreateRemoveConcurrently(t *testing.T) {
	hp := newHubPool()

	const channelsNo = 100
	const clientsPerChannel = 1000
	const messagesSentToEachHub = 100

	// Hubs returned for the first client of every channel, kept for the final check.
	var hubs sync.Map

	var wg sync.WaitGroup

	// First we create `channelsNo` goroutines. Each of them creates `clientsPerChannel` sub-goroutines.
	// This gives us `channelsNo*clientsPerChannel` sub-goroutines and `channelsNo` parent goroutines.
	// Each of them will call `wg.Done()` once and we can't progress until all of them are done.
	wg.Add(channelsNo*clientsPerChannel + channelsNo)

	// We will unregister `channelsNo*clientsPerChannel` clients. We start a fakeReadPump for each of them
	// and wait for it to finish. The pump returns once the client's send channel is closed
	// (presumably by unregisterClient; the hub code is not part of this file).
	wg.Add(channelsNo * clientsPerChannel)

	for i := 0; i < channelsNo; i++ {
		channelID := fmt.Sprintf("channel-%d", i)

		first, hub := hp.registerClient(channelID, &websocket.Conn{})
		hubs.Store(hub, struct{}{})

		// The first client stays registered, so its reader has no matching wg.Add.
		// Drain its send channel without touching wg: if this reader ever returned
		// and called wg.Done, the WaitGroup counter would go negative and panic.
		go func() {
			for range first.send {
			}
		}()

		// Broadcast while clients come and go. These goroutines are not tracked by
		// wg; the test does not wait for them.
		go func() {
			for m := 0; m < messagesSentToEachHub; m++ {
				hub.broadcastMsg([]byte("test"))
			}
		}()

		go func() {
			defer wg.Done()
			for j := 0; j < clientsPerChannel; j++ {
				c, h := hp.registerClient(channelID, &websocket.Conn{})
				go fakeReadPump(c.send, &wg)
				go func() {
					h.unregisterClient(c)
					wg.Done()
				}()
			}
		}()
	}

	wg.Wait()

	// Every registering/unregistering goroutine has finished, so hp.hubs is no
	// longer being modified by this test.
	for c, hub := range hp.hubs {
		if hub.isEmpty() {
			t.Fatalf("Empty hub found in channel: %q", c)
		}
	}

	hubs.Range(func(key, value any) bool {
		h1 := key.(*Hub)
		if !h1.isEmpty() {
			// A hub that still has clients must be the one the pool knows under its ID.
			if h2, ok := hp.hubs[h1.id]; !ok || h1 != h2 {
				t.Fatalf("Non-empty hub was evicted from hub pool: %q", h1.id)
			}
		}
		return true
	})
}
// fakeReadPump receives and discards every message from c until the channel is
// closed, and calls wg.Done when it returns.
func fakeReadPump(c chan []byte, wg *sync.WaitGroup) {
	defer wg.Done()

	for {
		// A receive with ok == false means the channel is closed and drained.
		if _, open := <-c; !open {
			return
		}
	}
}