mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-07 06:55:49 +01:00
initial version of pairing and proxy (#23)
* initial version of pairing and proxy * apply review comments and rework proxy * delete expires pairs
This commit is contained in:
parent
1413d107b3
commit
dbd4245b6f
@ -1,35 +0,0 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/http"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
addr string
|
||||
}
|
||||
|
||||
func NewServer(addr string) *Server {
|
||||
router := gin.New()
|
||||
|
||||
router.Use(recovery.RecoveryMiddleware())
|
||||
router.Use(http.RequestIdMiddleware())
|
||||
router.Use(http.CorrelationIdMiddleware())
|
||||
router.Use(http.RequestJsonLogger())
|
||||
|
||||
router.GET("/health", func(context *gin.Context) {
|
||||
context.Status(200)
|
||||
})
|
||||
|
||||
return &Server{
|
||||
router: router,
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
return s.router.Run(s.addr)
|
||||
}
|
39
internal/pass/pairing/auth.go
Normal file
39
internal/pass/pairing/auth.go
Normal file
@ -0,0 +1,39 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// VerifyPairingToken verifies pairing token and returns extension_id
|
||||
func (p *Pairing) VerifyPairingToken(ctx context.Context, pairingToken string) (string, error) {
|
||||
// TODO verify pairing token and take extension from token, this is for debug only.
|
||||
extensionID := pairingToken
|
||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||
if !ok {
|
||||
return "", errors.New("extension is not configured")
|
||||
}
|
||||
return extensionID, nil
|
||||
}
|
||||
|
||||
// VerifyProxyToken verifies proxy token and returns extension_id
|
||||
func (p *Pairing) VerifyProxyToken(ctx context.Context, proxyToken string) (string, error) {
|
||||
// TODO verify proxy token and take extension from token, this is for debug only.
|
||||
extensionID := proxyToken
|
||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||
if !ok {
|
||||
return "", errors.New("extension is not configured")
|
||||
}
|
||||
return extensionID, nil
|
||||
}
|
||||
|
||||
// VerifyConnectionToken verifies connection token and returns extension_id
|
||||
func (p *Pairing) VerifyConnectionToken(ctx context.Context, connectionToken string) (string, error) {
|
||||
// TODO verify proxy token and take extension from token, this is for debug only.
|
||||
extensionID := connectionToken
|
||||
ok := p.store.ExtensionExists(ctx, extensionID)
|
||||
if !ok {
|
||||
return "", errors.New("extension is not configured")
|
||||
}
|
||||
return extensionID, nil
|
||||
}
|
19
internal/pass/pairing/entities.go
Normal file
19
internal/pass/pairing/entities.go
Normal file
@ -0,0 +1,19 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MobileDevice struct {
|
||||
DeviceID string
|
||||
FCMToken string
|
||||
}
|
||||
|
||||
type PairingInfo struct {
|
||||
Device MobileDevice
|
||||
PairedAt time.Time
|
||||
}
|
||||
|
||||
func (pi *PairingInfo) IsPaired() bool {
|
||||
return !pi.PairedAt.IsZero()
|
||||
}
|
160
internal/pass/pairing/handlers.go
Normal file
160
internal/pass/pairing/handlers.go
Normal file
@ -0,0 +1,160 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
func BrowserExtensionConfigureHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
var req ConfigureBrowserExtensionRequest
|
||||
if err := gCtx.BindJSON(&req); err != nil {
|
||||
gCtx.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := uuid.Parse(req.ExtensionID); err != nil {
|
||||
gCtx.String(http.StatusBadRequest, "extension_id is not valid uuid")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := pairingApp.ConfigureBrowserExtension(gCtx, req)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to configure: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
gCtx.JSON(http.StatusCreated, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func BrowserExtensionWaitForConnHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
extensionID, err := pairingApp.VerifyPairingToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify pairing token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
pairingApp.ServePairingWS(gCtx.Writer, gCtx.Request, extensionID)
|
||||
}
|
||||
}
|
||||
|
||||
func BrowserExtensionProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
extensionID, err := pairingApp.VerifyProxyToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify proxy token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
pairingInfo, err := pairingApp.GetPairingInfo(gCtx, extensionID)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get pairing info: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !pairingInfo.IsPaired() {
|
||||
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
|
||||
return
|
||||
}
|
||||
proxyApp.ServeExtensionProxyToMobileWS(gCtx.Writer, gCtx.Request, extensionID, pairingInfo.Device.DeviceID)
|
||||
}
|
||||
}
|
||||
|
||||
func MobileConfirmHandler(pairingApp *Pairing) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify connection token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var req ConfirmPairingRequest
|
||||
if err := gCtx.BindJSON(&req); err != nil {
|
||||
gCtx.String(http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := uuid.Parse(req.DeviceID); err != nil {
|
||||
gCtx.String(http.StatusBadRequest, "extension_id is not valid uuid")
|
||||
return
|
||||
}
|
||||
|
||||
if err := pairingApp.ConfirmPairing(gCtx, req, extensionID); err != nil {
|
||||
logging.Errorf("Failed to ConfirmPairing: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MobileProxyHandler(pairingApp *Pairing, proxyApp *Proxy) gin.HandlerFunc {
|
||||
return func(gCtx *gin.Context) {
|
||||
// TODO: consider moving auth to middleware.
|
||||
token, err := tokenFromRequest(gCtx)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to get token from request: %v", err)
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
extensionID, err := pairingApp.VerifyConnectionToken(gCtx, token)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to verify connection token: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
log := logging.WithField("extension_id", extensionID)
|
||||
pairingInfo, err := pairingApp.GetPairingInfo(gCtx, extensionID)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get pairing info: %v", err)
|
||||
gCtx.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !pairingInfo.IsPaired() {
|
||||
gCtx.String(http.StatusForbidden, "Pairing is not yet done")
|
||||
return
|
||||
}
|
||||
proxyApp.ServeMobileProxyToExtensionWS(gCtx.Writer, gCtx.Request, pairingInfo.Device.DeviceID)
|
||||
}
|
||||
}
|
||||
|
||||
func tokenFromRequest(gCtx *gin.Context) (string, error) {
|
||||
tokenHeader := gCtx.GetHeader("Authorization")
|
||||
if tokenHeader == "" {
|
||||
return "", errors.New("missing Authorization header")
|
||||
}
|
||||
splitToken := strings.Split(tokenHeader, "Bearer ")
|
||||
if len(splitToken) != 2 {
|
||||
gCtx.Status(http.StatusForbidden)
|
||||
return "", errors.New("missing 'Bearer: value'")
|
||||
}
|
||||
return splitToken[1], nil
|
||||
}
|
71
internal/pass/pairing/memorystore.go
Normal file
71
internal/pass/pairing/memorystore.go
Normal file
@ -0,0 +1,71 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryStore keeps in memory pairing between extension and mobile.
|
||||
//
|
||||
// TODO: check ttlcache pkg, right now entries are not invalidated.
|
||||
type MemoryStore struct {
|
||||
mu sync.Mutex
|
||||
extensionsMap map[string]Item
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
ExtensionID string
|
||||
Expires time.Time
|
||||
PairingInfo PairingInfo
|
||||
}
|
||||
|
||||
func NewMemoryStore() *MemoryStore {
|
||||
return &MemoryStore{
|
||||
extensionsMap: make(map[string]Item),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) AddExtension(_ context.Context, extensionID string) {
|
||||
s.setItem(extensionID, Item{ExtensionID: extensionID})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ExtensionExists(_ context.Context, extensionID string) bool {
|
||||
_, ok := s.getItem(extensionID)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetPairingInfo(ctx context.Context, extensionID string) (PairingInfo, error) {
|
||||
v, ok := s.getItem(extensionID)
|
||||
if !ok {
|
||||
return PairingInfo{}, errors.New("extension does not exists")
|
||||
}
|
||||
return v.PairingInfo, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error {
|
||||
_, ok := s.getItem(extensionID)
|
||||
if !ok {
|
||||
return errors.New("extension does not exists")
|
||||
}
|
||||
s.setItem(extensionID, Item{
|
||||
ExtensionID: extensionID,
|
||||
Expires: time.Time{},
|
||||
PairingInfo: pi,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) setItem(key string, item Item) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.extensionsMap[key] = item
|
||||
}
|
||||
|
||||
func (s *MemoryStore) getItem(key string) (Item, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
v, ok := s.extensionsMap[key]
|
||||
return v, ok
|
||||
}
|
162
internal/pass/pairing/pairing.go
Normal file
162
internal/pass/pairing/pairing.go
Normal file
@ -0,0 +1,162 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
type Pairing struct {
|
||||
store store
|
||||
}
|
||||
|
||||
type store interface {
|
||||
AddExtension(ctx context.Context, extensionID string)
|
||||
ExtensionExists(ctx context.Context, extensionID string) bool
|
||||
GetPairingInfo(ctx context.Context, extensionID string) (PairingInfo, error)
|
||||
SetPairingInfo(ctx context.Context, extensionID string, pi PairingInfo) error
|
||||
}
|
||||
|
||||
func NewPairingApp() *Pairing {
|
||||
return &Pairing{
|
||||
store: NewMemoryStore(),
|
||||
}
|
||||
}
|
||||
|
||||
type ConfigureBrowserExtensionRequest struct {
|
||||
ExtensionID string `json:"extension_id"`
|
||||
}
|
||||
|
||||
type ConfigureBrowserExtensionResponse struct {
|
||||
BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"`
|
||||
ConnectionToken string `json:"connection_token"`
|
||||
}
|
||||
|
||||
func (p *Pairing) ConfigureBrowserExtension(ctx context.Context, req ConfigureBrowserExtensionRequest) (ConfigureBrowserExtensionResponse, error) {
|
||||
p.store.AddExtension(ctx, req.ExtensionID)
|
||||
// TODO: generate connection token and pairing token.
|
||||
connectionToken := req.ExtensionID
|
||||
pairingToken := req.ExtensionID
|
||||
|
||||
return ConfigureBrowserExtensionResponse{
|
||||
ConnectionToken: connectionToken,
|
||||
BrowserExtensionPairingToken: pairingToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ExtensionWaitForConnectionInput struct {
|
||||
ResponseWriter http.ResponseWriter
|
||||
HttpReq *http.Request
|
||||
}
|
||||
|
||||
type WaitForConnectionResponse struct {
|
||||
BrowserExtensionProxyToken string `json:"browser_extension_proxy_token"`
|
||||
Status string `json:"status"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID string) {
|
||||
log := logging.WithField("extension_id", extID)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServePairingWS: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
log.Info("Starting pairing WS")
|
||||
|
||||
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
maxWaitTime = 3 * time.Minute
|
||||
checkIfConnectedInterval = time.Second
|
||||
)
|
||||
maxWaitC := time.After(maxWaitTime)
|
||||
// TODO: consider returning event from store on change.
|
||||
connectedCheckTicker := time.NewTicker(checkIfConnectedInterval)
|
||||
defer connectedCheckTicker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-maxWaitC:
|
||||
log.Info("Closing paring ws after timeout")
|
||||
return
|
||||
case <-connectedCheckTicker.C:
|
||||
if deviceID, pairingDone := p.isExtensionPaired(r.Context(), extID, log); pairingDone {
|
||||
if err := p.sendTokenAndCloseConn(extID, deviceID, conn); err != nil {
|
||||
log.Errorf("Failed to send token: %v", err)
|
||||
return
|
||||
}
|
||||
log.WithField("device_id", deviceID).Infof("Paring ws finished")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (string, bool) {
|
||||
pairingInfo, err := p.store.GetPairingInfo(ctx, extID)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get pairing info")
|
||||
return "", false
|
||||
}
|
||||
return pairingInfo.Device.DeviceID, pairingInfo.IsPaired()
|
||||
}
|
||||
|
||||
func (p *Pairing) sendTokenAndCloseConn(extID, deviceID string, conn *websocket.Conn) error {
|
||||
// generate token here
|
||||
if err := conn.WriteJSON(WaitForConnectionResponse{
|
||||
// TODO: replace with real token.
|
||||
BrowserExtensionProxyToken: extID,
|
||||
Status: "ok",
|
||||
DeviceID: deviceID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to write to extension: %v", err)
|
||||
}
|
||||
return conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
}
|
||||
|
||||
// GetPairingInfo returns paired device and information if pairing was done.
|
||||
func (p *Pairing) GetPairingInfo(ctx context.Context, extensionID string) (PairingInfo, error) {
|
||||
return p.store.GetPairingInfo(ctx, extensionID)
|
||||
}
|
||||
|
||||
type ConfirmPairingRequest struct {
|
||||
FCMToken string `json:"fcm_token"`
|
||||
DeviceID string `json:"device_id"`
|
||||
}
|
||||
|
||||
func (p *Pairing) ConfirmPairing(ctx context.Context, req ConfirmPairingRequest, extensionID string) error {
|
||||
return p.store.SetPairingInfo(ctx, extensionID, PairingInfo{
|
||||
Device: MobileDevice{
|
||||
DeviceID: req.DeviceID,
|
||||
FCMToken: req.FCMToken,
|
||||
},
|
||||
PairedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4 * 1024,
|
||||
WriteBufferSize: 4 * 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
allowedOrigin := os.Getenv("WEBSOCKET_ALLOWED_ORIGIN")
|
||||
|
||||
if allowedOrigin != "" {
|
||||
return r.Header.Get("Origin") == allowedOrigin
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
}
|
259
internal/pass/pairing/proxy.go
Normal file
259
internal/pass/pairing/proxy.go
Normal file
@ -0,0 +1,259 @@
|
||||
package pairing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
proxyPool *proxyPool
|
||||
}
|
||||
|
||||
func NewProxy() *Proxy {
|
||||
proxyPool := &proxyPool{proxies: map[string]*proxyPair{}}
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
for {
|
||||
<-ticker.C
|
||||
proxyPool.deleteExpiresPairs()
|
||||
}
|
||||
}()
|
||||
return &Proxy{
|
||||
proxyPool: proxyPool,
|
||||
}
|
||||
}
|
||||
|
||||
type proxyPool struct {
|
||||
mu sync.Mutex
|
||||
proxies map[string]*proxyPair
|
||||
}
|
||||
|
||||
// registerMobileConn register proxyPair if not existing in pool and returns it.
|
||||
func (pp *proxyPool) getOrCreateProxyPair(deviceID string) *proxyPair {
|
||||
// TODO: handle delete.
|
||||
// TODO: right now two connections to the same WS results in race for messages/ decide if we want multiple conn or not.
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
v, ok := pp.proxies[deviceID]
|
||||
if !ok {
|
||||
v = initProxyPair()
|
||||
}
|
||||
pp.proxies[deviceID] = v
|
||||
return v
|
||||
}
|
||||
|
||||
func (pp *proxyPool) deleteExpiresPairs() {
|
||||
pp.mu.Lock()
|
||||
defer pp.mu.Unlock()
|
||||
|
||||
for key, pair := range pp.proxies {
|
||||
if time.Now().After(pair.expiresAt) {
|
||||
delete(pp.proxies, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type proxyPair struct {
|
||||
toMobileDataCh chan []byte
|
||||
toExtensionDataCh chan []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// initProxyPair returns proxyPair and runs loop responsible for proxing data.
|
||||
func initProxyPair() *proxyPair {
|
||||
const proxyTimeout = 3 * time.Minute
|
||||
return &proxyPair{
|
||||
toMobileDataCh: make(chan []byte),
|
||||
toExtensionDataCh: make(chan []byte),
|
||||
expiresAt: time.Now().Add(proxyTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
newline = []byte{'\n'}
|
||||
space = []byte{' '}
|
||||
|
||||
acceptedCloseStatus = []int{
|
||||
websocket.CloseNormalClosure,
|
||||
websocket.CloseGoingAway,
|
||||
websocket.CloseNoStatusReceived,
|
||||
websocket.CloseAbnormalClosure,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// Maximum message size allowed from peer.
|
||||
maxMessageSize = 4 * 1048
|
||||
)
|
||||
|
||||
// client is a responsible for reading from read chan and sending it over wsConn
|
||||
// and reading fom wsChan and sending it over send chan
|
||||
type client struct {
|
||||
send chan []byte
|
||||
read chan []byte
|
||||
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func newClient(wsConn *websocket.Conn, send, read chan []byte) *client {
|
||||
return &client{
|
||||
send: send,
|
||||
read: read,
|
||||
conn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
// readPump pumps messages from the websocket connection to send.
|
||||
//
|
||||
// The application runs readPump in a per-connection goroutine. The application
|
||||
// ensures that there is at most one reader on a connection by executing all
|
||||
// reads from this goroutine.
|
||||
func (c *client) readPump() {
|
||||
defer func() {
|
||||
c.conn.Close()
|
||||
close(c.send)
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Error("Websocket connection closed unexpected")
|
||||
} else {
|
||||
logging.WithFields(logging.Fields{
|
||||
"reason": err.Error(),
|
||||
}).Info("Connection closed")
|
||||
}
|
||||
break
|
||||
}
|
||||
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
|
||||
c.send <- message
|
||||
}
|
||||
}
|
||||
|
||||
// writePump pumps messages from the read chan to the websocket connection.
|
||||
//
|
||||
// A goroutine running writePump is started for each connection. The
|
||||
// application ensures that there is at most one writer to a connection by
|
||||
// executing all writes from this goroutine.
|
||||
func (c *client) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.read:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// The hub closed the channel.
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := c.conn.NextWriter(websocket.TextMessage)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
w.Write(message)
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeExtensionProxyToMobileWS(w http.ResponseWriter, r *http.Request, extID, deviceID string) {
|
||||
log := logging.WithField("extension_id", extID).WithField("device_id", deviceID)
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to upgrade on ServeExtensionProxyToMobileWS: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Starting ServeExtensionProxyToMobileWS")
|
||||
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
|
||||
client := newClient(conn, proxyPair.toMobileDataCh, proxyPair.toExtensionDataCh)
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.writePump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.readPump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
|
||||
client.conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) ServeMobileProxyToExtensionWS(w http.ResponseWriter, r *http.Request, deviceID string) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logging.Errorf("Failed to upgrade on ServeMobileProxyToExtensionWS: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logging.Infof("Starting ServeMobileProxyToExtensionWS for dev: %v", deviceID)
|
||||
proxyPair := p.proxyPool.getOrCreateProxyPair(deviceID)
|
||||
|
||||
client := newClient(conn, proxyPair.toExtensionDataCh, proxyPair.toMobileDataCh)
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.writePump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
client.readPump()
|
||||
})
|
||||
|
||||
go recovery.DoNotPanic(func() {
|
||||
disconnectAfter := 3 * time.Minute
|
||||
timeout := time.After(disconnectAfter)
|
||||
|
||||
<-timeout
|
||||
logging.Info("Connection closed after", disconnectAfter)
|
||||
|
||||
client.conn.Close()
|
||||
})
|
||||
}
|
44
internal/pass/server.go
Normal file
44
internal/pass/server.go
Normal file
@ -0,0 +1,44 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
httphelpers "github.com/twofas/2fas-server/internal/common/http"
|
||||
"github.com/twofas/2fas-server/internal/common/recovery"
|
||||
"github.com/twofas/2fas-server/internal/pass/pairing"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
addr string
|
||||
}
|
||||
|
||||
func NewServer(addr string) *Server {
|
||||
pairingApp := pairing.NewPairingApp()
|
||||
proxyApp := pairing.NewProxy()
|
||||
|
||||
router := gin.New()
|
||||
router.Use(recovery.RecoveryMiddleware())
|
||||
router.Use(httphelpers.RequestIdMiddleware())
|
||||
router.Use(httphelpers.CorrelationIdMiddleware())
|
||||
// TODO: don't log auth headers.
|
||||
router.Use(httphelpers.RequestJsonLogger())
|
||||
|
||||
router.GET("/health", func(context *gin.Context) {
|
||||
context.Status(200)
|
||||
})
|
||||
|
||||
router.POST("/browser_extension/configure", pairing.BrowserExtensionConfigureHandler(pairingApp))
|
||||
router.GET("/browser_extension/wait_for_connection", pairing.BrowserExtensionWaitForConnHandler(pairingApp))
|
||||
router.GET("/browser_extension/proxy_to_mobile", pairing.BrowserExtensionProxyHandler(pairingApp, proxyApp))
|
||||
router.POST("/mobile/confirm", pairing.MobileConfirmHandler(pairingApp))
|
||||
router.GET("/mobile/proxy_to_browser_extension", pairing.MobileProxyHandler(pairingApp, proxyApp))
|
||||
|
||||
return &Server{
|
||||
router: router,
|
||||
addr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
return s.router.Run(s.addr)
|
||||
}
|
Loading…
Reference in New Issue
Block a user