feat: rate limit using redis (#20)

This commit is contained in:
Tobiasz Heller 2024-01-02 09:48:34 +01:00 committed by GitHub
parent 6b3fcffe1f
commit f3706182cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 159 additions and 88 deletions

4
.env
View File

@ -9,3 +9,7 @@ MYSQL_PASSWORD=2fas
API_LISTEN_ADDR=:8080 API_LISTEN_ADDR=:8080
WEBSOCKET_LISTEN_ADDR=:8081 WEBSOCKET_LISTEN_ADDR=:8081
SECURITY_RATE_LIMIT_IP=1000
SECURITY_RATE_LIMIT_BE=100
SECURITY_RATE_LIMIT_MOBILE=100

View File

@ -1,7 +1,6 @@
package config package config
import ( import (
"os"
"strings" "strings"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -45,23 +44,9 @@ type AppConfig struct {
} }
type SecurityConfig struct { type SecurityConfig struct {
TrustedIP []string `mapstructure:"trusted_ip" json:"trusted_ip"` RateLimitIP int `mapstructure:"rate_limit_ip" json:"rate_limit_ip"`
} RateLimitMobile int `mapstructure:"rate_limit_mobile" json:"rate_limit_mobile"`
RateLimitBE int `mapstructure:"rate_limit_be" json:"rate_limit_be"`
func (c *SecurityConfig) IsIpTrusted(ip string) bool {
env := os.Getenv("ENV")
if env == "testing" || env == "development" {
return true
}
for _, trustedIp := range c.TrustedIP {
if ip == trustedIp {
return true
}
}
return false
} }
type WebsocketConfig struct { type WebsocketConfig struct {
@ -111,20 +96,20 @@ func initViper(configFilePath string) {
viper.BindEnv("icons.s3_access_key_id", "ICONS_S3_ACCESS_KEY_ID") viper.BindEnv("icons.s3_access_key_id", "ICONS_S3_ACCESS_KEY_ID")
viper.BindEnv("icons.s3_access_secret_key", "ICONS_S3_ACCESS_SECRET_KEY") viper.BindEnv("icons.s3_access_secret_key", "ICONS_S3_ACCESS_SECRET_KEY")
viper.BindEnv("security.rate_limit_ip", "SECURITY_RATE_LIMIT_IP")
viper.BindEnv("security.rate_limit_be", "SECURITY_RATE_LIMIT_BE")
viper.BindEnv("security.rate_limit_mobile", "SECURITY_RATE_LIMIT_MOBILE")
if configFilePath != "" { if configFilePath != "" {
viper.SetConfigFile(configFilePath) viper.SetConfigFile(configFilePath)
} }
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
logging.Fatal("failed to read the configuration file: %s", err) logging.Fatal("failed to read the configuration file: %s", err)
} }
err = viper.Unmarshal(&Config) err = viper.Unmarshal(&Config)
Config.Security.TrustedIP = viper.GetStringSlice("security_trusted_ip")
if err != nil { if err != nil {
logging.Fatal("Can not unmarshal configuration", err) logging.Fatal("Can not unmarshal configuration", err)
} }

3
go.mod
View File

@ -10,13 +10,14 @@ require (
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.15.5 github.com/go-playground/validator/v10 v10.15.5
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis_rate/v10 v10.0.1
github.com/go-sql-driver/mysql v1.7.1 github.com/go-sql-driver/mysql v1.7.1
github.com/google/uuid v1.3.1 github.com/google/uuid v1.3.1
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/jaswdr/faker v1.16.0 github.com/jaswdr/faker v1.16.0
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose/v3 v3.15.1 github.com/pressly/goose/v3 v3.15.1
github.com/redis/go-redis/v9 v9.3.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/viper v1.17.0 github.com/spf13/viper v1.17.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4

18
go.sum
View File

@ -62,6 +62,10 @@ github.com/avast/retry-go/v4 v4.5.0 h1:QoRAZZ90cj5oni2Lsgl2GW8mNTnUCnmpx/iKpwVis
github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5Aq1fboC3+I= github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5Aq1fboC3+I=
github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4= github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4=
github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE= github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE=
@ -128,8 +132,8 @@ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24= github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24=
github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
@ -289,12 +293,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
@ -308,6 +306,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/pressly/goose/v3 v3.15.1 h1:dKaJ1SdLvS/+HtS8PzFT0KBEtICC1jewLXM+b3emlv8= github.com/pressly/goose/v3 v3.15.1 h1:dKaJ1SdLvS/+HtS8PzFT0KBEtICC1jewLXM+b3emlv8=
github.com/pressly/goose/v3 v3.15.1/go.mod h1:0E3Yg/+EwYzO6Rz2P98MlClFgIcoujbVRs575yi3iIM= github.com/pressly/goose/v3 v3.15.1/go.mod h1:0E3Yg/+EwYzO6Rz2P98MlClFgIcoujbVRs575yi3iIM=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -720,8 +720,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

View File

@ -1,17 +1,18 @@
package security package security
import ( import (
"context"
"fmt" "fmt"
"net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/rate_limit" "github.com/twofas/2fas-server/internal/common/rate_limit"
"time"
) )
var browserExtensionApiBandwidthAbuseThreshold = 100 const defaultBrowserExtensionApiBandwidthAbuseThreshold = 100
func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
extensionId := c.Param("extension_id") extensionId := c.Param("extension_id")
@ -21,12 +22,15 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter
key := fmt.Sprintf("security.api.browser_extension.bandwidth.%s", extensionId) key := fmt.Sprintf("security.api.browser_extension.bandwidth.%s", extensionId)
limitValue := rateLimitValue
if limitValue == 0 {
limitValue = defaultBrowserExtensionApiBandwidthAbuseThreshold
}
rate := rate_limit.Rate{ rate := rate_limit.Rate{
TimeUnit: time.Minute, TimeUnit: time.Minute,
Limit: browserExtensionApiBandwidthAbuseThreshold, Limit: limitValue,
} }
limitReached := rateLimiter.Test(c, key, rate)
limitReached := rateLimiter.Test(context.Background(), key, rate)
if limitReached { if limitReached {
logging.WithFields(logging.Fields{ logging.WithFields(logging.Fields{
@ -34,7 +38,8 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter
"uri": c.Request.URL.String(), "uri": c.Request.URL.String(),
"browser_extension_id": extensionId, "browser_extension_id": extensionId,
"ip": c.ClientIP(), "ip": c.ClientIP(),
}).Warning("API potentially abused at Browser Extension scope") }).Warning("API potentially abused at Browser Extension scope, blocking")
c.AbortWithStatus(http.StatusTooManyRequests)
} }
} }
} }

View File

@ -5,7 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/go-redis/redis/v8" "github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/config"
"github.com/twofas/2fas-server/internal/api/browser_extension/adapters" "github.com/twofas/2fas-server/internal/api/browser_extension/adapters"
"github.com/twofas/2fas-server/internal/api/browser_extension/app" "github.com/twofas/2fas-server/internal/api/browser_extension/app"
@ -125,8 +125,8 @@ func NewBrowserExtensionModule(
func (m *BrowserExtensionModule) RegisterPublicRoutes(router *gin.Engine) { func (m *BrowserExtensionModule) RegisterPublicRoutes(router *gin.Engine) {
rateLimiter := rate_limit.New(m.Redis) rateLimiter := rate_limit.New(m.Redis)
bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter) bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter, m.Config.Security.RateLimitBE)
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter) iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP)
publicRouter := router.Group("/") publicRouter := router.Group("/")
publicRouter.Use(iPAbuseAuditMiddleware) publicRouter.Use(iPAbuseAuditMiddleware)

View File

@ -7,7 +7,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/config"
mobile "github.com/twofas/2fas-server/internal/api/mobile/domain" mobile "github.com/twofas/2fas-server/internal/api/mobile/domain"
support "github.com/twofas/2fas-server/internal/api/support/domain" support "github.com/twofas/2fas-server/internal/api/support/domain"

View File

@ -2,7 +2,7 @@ package service
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/config"
"github.com/twofas/2fas-server/internal/api/health/ports" "github.com/twofas/2fas-server/internal/api/health/ports"
) )

View File

@ -1,18 +1,19 @@
package security package security
import ( import (
"context"
"fmt" "fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/rate_limit" "github.com/twofas/2fas-server/internal/common/rate_limit"
"strings"
"time"
) )
var mobileApiBandwidthAbuseThreshold = 100 const defaultMobileApiBandwidthAbuseThreshold = 100
func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
deviceId := c.Param("device_id") deviceId := c.Param("device_id")
extensionId := c.Param("extension_id") extensionId := c.Param("extension_id")
@ -25,13 +26,16 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle
fmt.Sprintf("security.api.mobile.bandwidth.%s.%s", deviceId, extensionId), fmt.Sprintf("security.api.mobile.bandwidth.%s.%s", deviceId, extensionId),
".", ".",
) )
limitValue := rateLimitValue
if limitValue == 0 {
limitValue = defaultMobileApiBandwidthAbuseThreshold
}
rate := rate_limit.Rate{ rate := rate_limit.Rate{
TimeUnit: time.Minute, TimeUnit: time.Minute,
Limit: mobileApiBandwidthAbuseThreshold, Limit: limitValue,
} }
limitReached := rateLimiter.Test(context.Background(), key, rate) limitReached := rateLimiter.Test(c, key, rate)
if limitReached { if limitReached {
logging.WithFields(logging.Fields{ logging.WithFields(logging.Fields{
@ -40,7 +44,8 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle
"device_id": deviceId, "device_id": deviceId,
"browser_extension_id": extensionId, "browser_extension_id": extensionId,
"ip": c.ClientIP(), "ip": c.ClientIP(),
}).Warning("API potentially abused at mobile scope") }).Warning("API potentially abused at mobile scope, blocking")
c.AbortWithStatus(http.StatusTooManyRequests)
} }
} }
} }

View File

@ -5,7 +5,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/go-redis/redis/v8" "github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/config"
browser_extension_adapters "github.com/twofas/2fas-server/internal/api/browser_extension/adapters" browser_extension_adapters "github.com/twofas/2fas-server/internal/api/browser_extension/adapters"
"github.com/twofas/2fas-server/internal/api/mobile/adapters" "github.com/twofas/2fas-server/internal/api/mobile/adapters"
@ -117,8 +117,8 @@ func NewMobileModule(config config.Configuration, gorm *gorm.DB, database *sql.D
func (m *MobileModule) RegisterPublicRoutes(router *gin.Engine) { func (m *MobileModule) RegisterPublicRoutes(router *gin.Engine) {
rateLimiter := rate_limit.New(m.Redis) rateLimiter := rate_limit.New(m.Redis)
bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter) bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitMobile)
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter) iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP)
publicRouter := router.Group("/") publicRouter := router.Group("/")
publicRouter.Use(iPAbuseAuditMiddleware) publicRouter.Use(iPAbuseAuditMiddleware)

View File

@ -2,8 +2,11 @@ package rate_limit
import ( import (
"context" "context"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/internal/common/logging"
) )
type Rate struct { type Rate struct {
@ -18,33 +21,33 @@ type RateLimiter interface {
type LimitHandler func() type LimitHandler func()
type RedisRateLimit struct { type RedisRateLimit struct {
Client *redis.Client limiter *redis_rate.Limiter
} }
func New(client *redis.Client) RateLimiter { func New(client *redis.Client) RateLimiter {
return &RedisRateLimit{Client: client} return &RedisRateLimit{
limiter: redis_rate.NewLimiter(client),
}
} }
// Test returns information if limit has been reached.
func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool { func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool {
counter, err := r.Client.Get(context.Background(), key).Int() res, err := r.limiter.Allow(ctx, key, redis_rate.Limit{
Rate: rate.Limit,
if err == redis.Nil { Burst: rate.Limit,
r.Client.Set(ctx, key, 1, rate.TimeUnit) Period: rate.TimeUnit,
})
return false
}
if err != nil { if err != nil {
logging.WithFields(logging.Fields{
"type": "security",
}).Warnf("Could not check rate limit: %v", err)
// for now we return that limit has not been reached.
return false return false
} }
if res.Allowed <= 0 {
if counter >= rate.Limit { // limit has been reached.
r.Client.Del(context.Background(), key)
return true return true
} else {
r.Client.Incr(context.Background(), key)
} }
return false return false
} }

View File

@ -2,7 +2,8 @@ package redis
import ( import (
"fmt" "fmt"
"github.com/go-redis/redis/v8"
"github.com/redis/go-redis/v9"
) )
var ( var (

View File

@ -1,35 +1,41 @@
package security package security
import ( import (
"context"
"fmt" "fmt"
"net/http"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/logging"
"github.com/twofas/2fas-server/internal/common/rate_limit" "github.com/twofas/2fas-server/internal/common/rate_limit"
"time"
) )
var apiBandwidthAbuseThreshold = 100 const defaultAPIBandwidthAbuseThreshold = 100
func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
clientIp := c.ClientIP() clientIp := c.ClientIP()
key := fmt.Sprintf("security.api.ip_bandwidth_audit.%s", clientIp) key := fmt.Sprintf("security.api.ip_bandwidth_audit.%s", clientIp)
limitValue := rateLimitValue
if limitValue == 0 {
limitValue = defaultAPIBandwidthAbuseThreshold
}
rate := rate_limit.Rate{ rate := rate_limit.Rate{
TimeUnit: time.Minute, TimeUnit: time.Minute,
Limit: apiBandwidthAbuseThreshold, Limit: limitValue,
} }
limitReached := rateLimiter.Test(context.Background(), key, rate) limitReached := rateLimiter.Test(c, key, rate)
if limitReached { if limitReached {
logging.WithFields(logging.Fields{ logging.WithFields(logging.Fields{
"type": "security", "type": "security",
"uri": c.Request.URL.String(), "uri": c.Request.URL.String(),
"ip": c.ClientIP(), "ip": c.ClientIP(),
}).Warning("API potentially abused by Client IP") }).Warning("API potentially abused by Client IP, blocking")
c.AbortWithStatus(http.StatusTooManyRequests)
} }
} }
} }

View File

@ -1,24 +1,87 @@
package tests package tests
import ( import (
"net/http"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/twofas/2fas-server/tests" "github.com/twofas/2fas-server/tests"
"golang.org/x/sync/errgroup"
) )
func Test_MobileApiBandwidthAbuse(t *testing.T) { func Test_MobileApiBandwidthAbuse(t *testing.T) {
someId := uuid.New() someId := uuid.New()
for i := 0; i <= 100; i++ { noOfRequest := 130
tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil) noOfWorkers := 20
responseCh := make(chan int, noOfRequest)
eg := errgroup.Group{}
eg.SetLimit(noOfWorkers)
for i := 0; i < noOfRequest; i++ {
eg.Go(func() error {
resp := tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil)
responseCh <- resp.StatusCode
return nil
})
} }
require.NoError(t, eg.Wait())
close(responseCh)
var got404, got429 int
for code := range responseCh {
switch code {
case http.StatusNotFound:
got404++
case http.StatusTooManyRequests:
got429++
default:
t.Fatalf("Unexpected code: %v", code)
}
}
// Default rate limit is 100 per minute.
// So we expect around 100 - 404, and around 30 - 429
require.InDelta(t, 100, got404, 2.0)
require.InDelta(t, 30, got429, 2.0)
} }
func Test_BrowserExtensionApiBandwidthAbuse(t *testing.T) { func Test_BrowserExtensionApiBandwidthAbuse(t *testing.T) {
someId := uuid.New() someId := uuid.New()
for i := 0; i <= 100; i++ { noOfRequest := 130
tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil) noOfWorkers := 20
responseCh := make(chan int, noOfRequest)
eg := errgroup.Group{}
eg.SetLimit(noOfWorkers)
for i := 0; i < noOfRequest; i++ {
eg.Go(func() error {
resp := tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil)
responseCh <- resp.StatusCode
return nil
})
} }
require.NoError(t, eg.Wait())
close(responseCh)
var got404, got429 int
for code := range responseCh {
switch code {
case http.StatusNotFound:
got404++
case http.StatusTooManyRequests:
got429++
default:
t.Fatalf("Unexpected code: %v", code)
}
}
// Default rate limit is 100 per minute.
// So we expect around 100 - 404, and around 30 - 429
require.InDelta(t, 100, got404, 2.0)
require.InDelta(t, 30, got429, 2.0)
} }