2022-12-31 10:22:38 +01:00
|
|
|
package common
|
|
|
|
|
|
|
|
import (
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/twofas/2fas-server/internal/common/logging"
)
|
|
|
|
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
|
|
ReadBufferSize: 4 * 1024,
|
|
|
|
WriteBufferSize: 4 * 1024,
|
|
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
|
|
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
|
|
|
|
|
|
|
if allowedOrigin != "" {
|
|
|
|
return r.Header.Get("Origin") == allowedOrigin
|
|
|
|
}
|
|
|
|
|
|
|
|
return true
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
type ConnectionHandler struct {
|
|
|
|
channels map[string]*Hub
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewConnectionHandler() *ConnectionHandler {
|
|
|
|
channels := make(map[string]*Hub)
|
|
|
|
|
|
|
|
return &ConnectionHandler{
|
|
|
|
channels: channels,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-07 23:34:29 +01:00
|
|
|
func (h *ConnectionHandler) Handler() gin.HandlerFunc {
|
2022-12-31 10:22:38 +01:00
|
|
|
return func(c *gin.Context) {
|
|
|
|
channel := c.Request.URL.Path
|
|
|
|
|
2023-02-01 13:05:12 +01:00
|
|
|
logging.WithDefaultField("channel", channel)
|
|
|
|
logging.WithDefaultField("ip", c.ClientIP())
|
|
|
|
|
|
|
|
logging.Info("New channel subscriber")
|
|
|
|
|
2022-12-31 10:22:38 +01:00
|
|
|
hub := h.getHub(channel)
|
|
|
|
|
|
|
|
h.serveWs(hub, c.Writer, c.Request)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (h *ConnectionHandler) getHub(channel string) *Hub {
|
|
|
|
var hub *Hub
|
|
|
|
|
|
|
|
hub, ok := h.channels[channel]
|
|
|
|
|
|
|
|
if !ok {
|
|
|
|
hub = NewHub()
|
|
|
|
|
|
|
|
go hub.Run()
|
|
|
|
|
|
|
|
h.channels[channel] = hub
|
|
|
|
}
|
|
|
|
|
|
|
|
return hub
|
|
|
|
}
|
|
|
|
|
|
|
|
func (h *ConnectionHandler) serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
|
|
|
|
conn, _ := upgrader.Upgrade(w, r, nil)
|
|
|
|
|
|
|
|
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
|
|
|
|
client.hub.register <- client
|
|
|
|
|
|
|
|
go client.writePump()
|
|
|
|
go client.readPump()
|
|
|
|
|
|
|
|
go func() {
|
2023-03-07 23:50:04 +01:00
|
|
|
disconnectAfter := 3 * time.Minute
|
|
|
|
timeout := time.After(disconnectAfter)
|
2023-02-01 13:05:12 +01:00
|
|
|
|
2023-03-07 23:50:04 +01:00
|
|
|
select {
|
|
|
|
case <-timeout:
|
2023-03-07 23:34:29 +01:00
|
|
|
logging.Info("Connection closed after", disconnectAfter)
|
2023-02-01 13:05:12 +01:00
|
|
|
|
2022-12-31 10:22:38 +01:00
|
|
|
client.hub.unregister <- client
|
|
|
|
client.conn.Close()
|
2023-03-07 23:50:04 +01:00
|
|
|
}
|
2022-12-31 10:22:38 +01:00
|
|
|
}()
|
|
|
|
}
|