From f3706182ccb9f7b9065d3195330dfcb9179f5bd9 Mon Sep 17 00:00:00 2001 From: Tobiasz Heller <14020794+tobiaszheller@users.noreply.github.com> Date: Tue, 2 Jan 2024 09:48:34 +0100 Subject: [PATCH] feat: rate limit using redis (#20) --- .env | 4 ++ config/config.go | 29 ++------ go.mod | 3 +- go.sum | 18 +++-- .../app/security/middleware.go | 21 +++--- .../api/browser_extension/service/service.go | 6 +- internal/api/health/ports/http.go | 2 +- internal/api/health/service/service.go | 2 +- .../api/mobile/app/security/middleware.go | 23 +++--- internal/api/mobile/service/service.go | 6 +- .../common/rate_limit/redis_rate_limit.go | 39 +++++----- internal/common/redis/client.go | 3 +- internal/common/security/middleware.go | 20 ++++-- tests/mobile/mobile_security_test.go | 71 +++++++++++++++++-- 14 files changed, 159 insertions(+), 88 deletions(-) diff --git a/.env b/.env index 06f9e4d..7750f31 100644 --- a/.env +++ b/.env @@ -9,3 +9,7 @@ MYSQL_PASSWORD=2fas API_LISTEN_ADDR=:8080 WEBSOCKET_LISTEN_ADDR=:8081 + +SECURITY_RATE_LIMIT_IP=1000 +SECURITY_RATE_LIMIT_BE=100 +SECURITY_RATE_LIMIT_MOBILE=100 diff --git a/config/config.go b/config/config.go index eb9d794..0d345ce 100644 --- a/config/config.go +++ b/config/config.go @@ -1,7 +1,6 @@ package config import ( - "os" "strings" "github.com/spf13/viper" @@ -45,23 +44,9 @@ type AppConfig struct { } type SecurityConfig struct { - TrustedIP []string `mapstructure:"trusted_ip" json:"trusted_ip"` -} - -func (c *SecurityConfig) IsIpTrusted(ip string) bool { - env := os.Getenv("ENV") - - if env == "testing" || env == "development" { - return true - } - - for _, trustedIp := range c.TrustedIP { - if ip == trustedIp { - return true - } - } - - return false + RateLimitIP int `mapstructure:"rate_limit_ip" json:"rate_limit_ip"` + RateLimitMobile int `mapstructure:"rate_limit_mobile" json:"rate_limit_mobile"` + RateLimitBE int `mapstructure:"rate_limit_be" json:"rate_limit_be"` } type WebsocketConfig struct { @@ -111,20 +96,20 @@ func initViper(configFilePath string) { viper.BindEnv("icons.s3_access_key_id", "ICONS_S3_ACCESS_KEY_ID") viper.BindEnv("icons.s3_access_secret_key", "ICONS_S3_ACCESS_SECRET_KEY") + viper.BindEnv("security.rate_limit_ip", "SECURITY_RATE_LIMIT_IP") + viper.BindEnv("security.rate_limit_be", "SECURITY_RATE_LIMIT_BE") + viper.BindEnv("security.rate_limit_mobile", "SECURITY_RATE_LIMIT_MOBILE") + if configFilePath != "" { viper.SetConfigFile(configFilePath) } err := viper.ReadInConfig() - if err != nil { logging.Fatal("failed to read the configuration file: %s", err) } err = viper.Unmarshal(&Config) - - Config.Security.TrustedIP = viper.GetStringSlice("security_trusted_ip") - if err != nil { logging.Fatal("Can not unmarshal configuration", err) } diff --git a/go.mod b/go.mod index ddfcc3b..b83f8da 100644 --- a/go.mod +++ b/go.mod @@ -10,13 +10,14 @@ require ( github.com/gin-contrib/cors v1.4.0 github.com/gin-gonic/gin v1.9.1 github.com/go-playground/validator/v10 v10.15.5 - github.com/go-redis/redis/v8 v8.11.5 + github.com/go-redis/redis_rate/v10 v10.0.1 github.com/go-sql-driver/mysql v1.7.1 github.com/google/uuid v1.3.1 github.com/gorilla/websocket v1.5.0 github.com/jaswdr/faker v1.16.0 github.com/pkg/errors v0.9.1 github.com/pressly/goose/v3 v3.15.1 + github.com/redis/go-redis/v9 v9.3.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 395ecb4..520e90b 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,10 @@ github.com/avast/retry-go/v4 v4.5.0 h1:QoRAZZ90cj5oni2Lsgl2GW8mNTnUCnmpx/iKpwVis github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5Aq1fboC3+I= github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4= github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE= @@ -128,8 +132,8 @@ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91 github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24= github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= -github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= -github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= +github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= @@ -289,12 +293,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= -github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= @@ -308,6 +306,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/pressly/goose/v3 v3.15.1 h1:dKaJ1SdLvS/+HtS8PzFT0KBEtICC1jewLXM+b3emlv8= github.com/pressly/goose/v3 v3.15.1/go.mod h1:0E3Yg/+EwYzO6Rz2P98MlClFgIcoujbVRs575yi3iIM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= +github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -720,8 +720,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/internal/api/browser_extension/app/security/middleware.go b/internal/api/browser_extension/app/security/middleware.go index 2f09b07..3b35fee 100644 --- a/internal/api/browser_extension/app/security/middleware.go +++ b/internal/api/browser_extension/app/security/middleware.go @@ -1,17 +1,18 @@ package security import ( - "context" "fmt" + "net/http" + "time" + "github.com/gin-gonic/gin" "github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/rate_limit" - "time" ) -var browserExtensionApiBandwidthAbuseThreshold = 100 +const defaultBrowserExtensionApiBandwidthAbuseThreshold = 100 -func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { +func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc { return func(c *gin.Context) { extensionId := c.Param("extension_id") @@ -21,12 +22,15 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter key := fmt.Sprintf("security.api.browser_extension.bandwidth.%s", extensionId) + limitValue := rateLimitValue + if limitValue == 0 { + limitValue = defaultBrowserExtensionApiBandwidthAbuseThreshold + } rate := rate_limit.Rate{ TimeUnit: time.Minute, - Limit: browserExtensionApiBandwidthAbuseThreshold, + Limit: limitValue, } - - limitReached := rateLimiter.Test(context.Background(), key, rate) + limitReached := rateLimiter.Test(c, key, rate) if limitReached { logging.WithFields(logging.Fields{ @@ -34,7 +38,8 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter "uri": c.Request.URL.String(), "browser_extension_id": extensionId, "ip": c.ClientIP(), - }).Warning("API potentially abused at Browser Extension scope") + }).Warning("API potentially abused at Browser Extension scope, blocking") + c.AbortWithStatus(http.StatusTooManyRequests) } } } diff --git a/internal/api/browser_extension/service/service.go b/internal/api/browser_extension/service/service.go index 2531040..b6ae8f2 100644 --- a/internal/api/browser_extension/service/service.go +++ b/internal/api/browser_extension/service/service.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/internal/api/browser_extension/adapters" "github.com/twofas/2fas-server/internal/api/browser_extension/app" @@ -125,8 +125,8 @@ func NewBrowserExtensionModule( func (m *BrowserExtensionModule) RegisterPublicRoutes(router *gin.Engine) { rateLimiter := rate_limit.New(m.Redis) - bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter) - iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter) + bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter, m.Config.Security.RateLimitBE) + iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP) publicRouter := router.Group("/") publicRouter.Use(iPAbuseAuditMiddleware) diff --git a/internal/api/health/ports/http.go b/internal/api/health/ports/http.go index 6c0e074..6ba0dde 100644 --- a/internal/api/health/ports/http.go +++ b/internal/api/health/ports/http.go @@ -7,7 +7,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/twofas/2fas-server/config" mobile "github.com/twofas/2fas-server/internal/api/mobile/domain" support "github.com/twofas/2fas-server/internal/api/support/domain" diff --git a/internal/api/health/service/service.go b/internal/api/health/service/service.go index 53f59c5..c03a6df 100644 --- a/internal/api/health/service/service.go +++ b/internal/api/health/service/service.go @@ -2,7 +2,7 @@ package service import ( "github.com/gin-gonic/gin" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/twofas/2fas-server/config" "github.com/twofas/2fas-server/internal/api/health/ports" ) diff --git a/internal/api/mobile/app/security/middleware.go b/internal/api/mobile/app/security/middleware.go index f0233ec..eda05fd 100644 --- a/internal/api/mobile/app/security/middleware.go +++ b/internal/api/mobile/app/security/middleware.go @@ -1,18 +1,19 @@ package security import ( - "context" "fmt" + "net/http" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/rate_limit" - "strings" - "time" ) -var mobileApiBandwidthAbuseThreshold = 100 +const defaultMobileApiBandwidthAbuseThreshold = 100 -func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { +func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc { return func(c *gin.Context) { deviceId := c.Param("device_id") extensionId := c.Param("extension_id") @@ -25,13 +26,16 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle fmt.Sprintf("security.api.mobile.bandwidth.%s.%s", deviceId, extensionId), ".", ) - + limitValue := rateLimitValue + if limitValue == 0 { + limitValue = defaultMobileApiBandwidthAbuseThreshold + } rate := rate_limit.Rate{ TimeUnit: time.Minute, - Limit: mobileApiBandwidthAbuseThreshold, + Limit: limitValue, } - limitReached := rateLimiter.Test(context.Background(), key, rate) + limitReached := rateLimiter.Test(c, key, rate) if limitReached { logging.WithFields(logging.Fields{ @@ -40,7 +44,8 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle "device_id": deviceId, "browser_extension_id": extensionId, "ip": c.ClientIP(), - }).Warning("API potentially abused at mobile scope") + }).Warning("API potentially abused at mobile scope, blocking") + c.AbortWithStatus(http.StatusTooManyRequests) } } } diff --git a/internal/api/mobile/service/service.go b/internal/api/mobile/service/service.go index 4c037a6..e6bf5b4 100644 --- a/internal/api/mobile/service/service.go +++ b/internal/api/mobile/service/service.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/twofas/2fas-server/config" browser_extension_adapters "github.com/twofas/2fas-server/internal/api/browser_extension/adapters" "github.com/twofas/2fas-server/internal/api/mobile/adapters" @@ -117,8 +117,8 @@ func NewMobileModule(config config.Configuration, gorm *gorm.DB, database *sql.D func (m *MobileModule) RegisterPublicRoutes(router *gin.Engine) { rateLimiter := rate_limit.New(m.Redis) - bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter) - iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter) + bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitMobile) + iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP) publicRouter := router.Group("/") publicRouter.Use(iPAbuseAuditMiddleware) diff --git a/internal/common/rate_limit/redis_rate_limit.go b/internal/common/rate_limit/redis_rate_limit.go index e835ffb..18c2882 100644 --- a/internal/common/rate_limit/redis_rate_limit.go +++ b/internal/common/rate_limit/redis_rate_limit.go @@ -2,8 +2,11 @@ package rate_limit import ( "context" - "github.com/go-redis/redis/v8" "time" + + "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" + "github.com/twofas/2fas-server/internal/common/logging" ) type Rate struct { @@ -18,33 +21,33 @@ type RateLimiter interface { type LimitHandler func() type RedisRateLimit struct { - Client *redis.Client + limiter *redis_rate.Limiter } func New(client *redis.Client) RateLimiter { - return &RedisRateLimit{Client: client} + return &RedisRateLimit{ + limiter: redis_rate.NewLimiter(client), + } } +// Test returns information if limit has been reached. func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool { - counter, err := r.Client.Get(context.Background(), key).Int() - - if err == redis.Nil { - r.Client.Set(ctx, key, 1, rate.TimeUnit) - - return false - } - + res, err := r.limiter.Allow(ctx, key, redis_rate.Limit{ + Rate: rate.Limit, + Burst: rate.Limit, + Period: rate.TimeUnit, + }) if err != nil { + logging.WithFields(logging.Fields{ + "type": "security", + }).Warnf("Could not check rate limit: %v", err) + + // for now we return that limit has not been reached. return false } - - if counter >= rate.Limit { - r.Client.Del(context.Background(), key) - + if res.Allowed <= 0 { + // limit has been reached. return true - } else { - r.Client.Incr(context.Background(), key) } - return false } diff --git a/internal/common/redis/client.go b/internal/common/redis/client.go index 9360c64..e012ec2 100644 --- a/internal/common/redis/client.go +++ b/internal/common/redis/client.go @@ -2,7 +2,8 @@ package redis import ( "fmt" - "github.com/go-redis/redis/v8" + + "github.com/redis/go-redis/v9" ) var ( diff --git a/internal/common/security/middleware.go b/internal/common/security/middleware.go index 067c8cb..f02dc99 100644 --- a/internal/common/security/middleware.go +++ b/internal/common/security/middleware.go @@ -1,35 +1,41 @@ package security import ( - "context" "fmt" + "net/http" + "time" + "github.com/gin-gonic/gin" "github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/common/rate_limit" - "time" ) -var apiBandwidthAbuseThreshold = 100 +const defaultAPIBandwidthAbuseThreshold = 100 -func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc { +func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc { return func(c *gin.Context) { clientIp := c.ClientIP() key := fmt.Sprintf("security.api.ip_bandwidth_audit.%s", clientIp) + limitValue := rateLimitValue + if limitValue == 0 { + limitValue = defaultAPIBandwidthAbuseThreshold + } rate := rate_limit.Rate{ TimeUnit: time.Minute, - Limit: apiBandwidthAbuseThreshold, + Limit: limitValue, } - limitReached := rateLimiter.Test(context.Background(), key, rate) + limitReached := rateLimiter.Test(c, key, rate) if limitReached { logging.WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "ip": c.ClientIP(), - }).Warning("API potentially abused by Client IP") + }).Warning("API potentially abused by Client IP, blocking") + c.AbortWithStatus(http.StatusTooManyRequests) } } } diff --git a/tests/mobile/mobile_security_test.go b/tests/mobile/mobile_security_test.go index cb5787f..689df09 100644 --- a/tests/mobile/mobile_security_test.go +++ b/tests/mobile/mobile_security_test.go @@ -1,24 +1,87 @@ package tests import ( + "net/http" "testing" "github.com/google/uuid" + "github.com/stretchr/testify/require" "github.com/twofas/2fas-server/tests" + "golang.org/x/sync/errgroup" ) func Test_MobileApiBandwidthAbuse(t *testing.T) { someId := uuid.New() - for i := 0; i <= 100; i++ { - tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil) + noOfRequest := 130 + noOfWorkers := 20 + responseCh := make(chan int, noOfRequest) + + eg := errgroup.Group{} + eg.SetLimit(noOfWorkers) + for i := 0; i < noOfRequest; i++ { + eg.Go(func() error { + resp := tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil) + + responseCh <- resp.StatusCode + + return nil + }) } + require.NoError(t, eg.Wait()) + close(responseCh) + + var got404, got429 int + for code := range responseCh { + switch code { + case http.StatusNotFound: + got404++ + case http.StatusTooManyRequests: + got429++ + default: + t.Fatalf("Unexpected code: %v", code) + } + } + // Default rate limit is 100 per minute. + // So we expect around 100 - 404, and around 30 - 429 + require.InDelta(t, 100, got404, 2.0) + require.InDelta(t, 30, got429, 2.0) } func Test_BrowserExtensionApiBandwidthAbuse(t *testing.T) { someId := uuid.New() - for i := 0; i <= 100; i++ { - tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil) + noOfRequest := 130 + noOfWorkers := 20 + responseCh := make(chan int, noOfRequest) + + eg := errgroup.Group{} + eg.SetLimit(noOfWorkers) + for i := 0; i < noOfRequest; i++ { + eg.Go(func() error { + resp := tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil) + + responseCh <- resp.StatusCode + + return nil + }) } + require.NoError(t, eg.Wait()) + close(responseCh) + + var got404, got429 int + for code := range responseCh { + switch code { + case http.StatusNotFound: + got404++ + case http.StatusTooManyRequests: + got429++ + default: + t.Fatalf("Unexpected code: %v", code) + } + } + // Default rate limit is 100 per minute. + // So we expect around 100 - 404, and around 30 - 429 + require.InDelta(t, 100, got404, 2.0) + require.InDelta(t, 30, got429, 2.0) }