mirror of
https://github.com/twofas/2fas-server.git
synced 2024-12-24 01:50:24 +01:00
feat: rate limit using redis (#20)
This commit is contained in:
parent
6b3fcffe1f
commit
f3706182cc
4
.env
4
.env
@ -9,3 +9,7 @@ MYSQL_PASSWORD=2fas
|
||||
|
||||
API_LISTEN_ADDR=:8080
|
||||
WEBSOCKET_LISTEN_ADDR=:8081
|
||||
|
||||
SECURITY_RATE_LIMIT_IP=1000
|
||||
SECURITY_RATE_LIMIT_BE=100
|
||||
SECURITY_RATE_LIMIT_MOBILE=100
|
||||
|
@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@ -45,23 +44,9 @@ type AppConfig struct {
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
TrustedIP []string `mapstructure:"trusted_ip" json:"trusted_ip"`
|
||||
}
|
||||
|
||||
func (c *SecurityConfig) IsIpTrusted(ip string) bool {
|
||||
env := os.Getenv("ENV")
|
||||
|
||||
if env == "testing" || env == "development" {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, trustedIp := range c.TrustedIP {
|
||||
if ip == trustedIp {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
RateLimitIP int `mapstructure:"rate_limit_ip" json:"rate_limit_ip"`
|
||||
RateLimitMobile int `mapstructure:"rate_limit_mobile" json:"rate_limit_mobile"`
|
||||
RateLimitBE int `mapstructure:"rate_limit_be" json:"rate_limit_be"`
|
||||
}
|
||||
|
||||
type WebsocketConfig struct {
|
||||
@ -111,20 +96,20 @@ func initViper(configFilePath string) {
|
||||
viper.BindEnv("icons.s3_access_key_id", "ICONS_S3_ACCESS_KEY_ID")
|
||||
viper.BindEnv("icons.s3_access_secret_key", "ICONS_S3_ACCESS_SECRET_KEY")
|
||||
|
||||
viper.BindEnv("security.rate_limit_ip", "SECURITY_RATE_LIMIT_IP")
|
||||
viper.BindEnv("security.rate_limit_be", "SECURITY_RATE_LIMIT_BE")
|
||||
viper.BindEnv("security.rate_limit_mobile", "SECURITY_RATE_LIMIT_MOBILE")
|
||||
|
||||
if configFilePath != "" {
|
||||
viper.SetConfigFile(configFilePath)
|
||||
}
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
|
||||
if err != nil {
|
||||
logging.Fatal("failed to read the configuration file: %s", err)
|
||||
}
|
||||
|
||||
err = viper.Unmarshal(&Config)
|
||||
|
||||
Config.Security.TrustedIP = viper.GetStringSlice("security_trusted_ip")
|
||||
|
||||
if err != nil {
|
||||
logging.Fatal("Can not unmarshal configuration", err)
|
||||
}
|
||||
|
3
go.mod
3
go.mod
@ -10,13 +10,14 @@ require (
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/go-playground/validator/v10 v10.15.5
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/go-redis/redis_rate/v10 v10.0.1
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jaswdr/faker v1.16.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.15.1
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/viper v1.17.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
|
18
go.sum
18
go.sum
@ -62,6 +62,10 @@ github.com/avast/retry-go/v4 v4.5.0 h1:QoRAZZ90cj5oni2Lsgl2GW8mNTnUCnmpx/iKpwVis
|
||||
github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5Aq1fboC3+I=
|
||||
github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4=
|
||||
github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
|
||||
github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE=
|
||||
@ -128,8 +132,8 @@ github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91
|
||||
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
|
||||
github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24=
|
||||
github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo=
|
||||
github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
||||
@ -289,12 +293,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
@ -308,6 +306,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
|
||||
github.com/pressly/goose/v3 v3.15.1 h1:dKaJ1SdLvS/+HtS8PzFT0KBEtICC1jewLXM+b3emlv8=
|
||||
github.com/pressly/goose/v3 v3.15.1/go.mod h1:0E3Yg/+EwYzO6Rz2P98MlClFgIcoujbVRs575yi3iIM=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
|
||||
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
@ -720,8 +720,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
|
@ -1,17 +1,18 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/rate_limit"
|
||||
"time"
|
||||
)
|
||||
|
||||
var browserExtensionApiBandwidthAbuseThreshold = 100
|
||||
const defaultBrowserExtensionApiBandwidthAbuseThreshold = 100
|
||||
|
||||
func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc {
|
||||
func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
extensionId := c.Param("extension_id")
|
||||
|
||||
@ -21,12 +22,15 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter
|
||||
|
||||
key := fmt.Sprintf("security.api.browser_extension.bandwidth.%s", extensionId)
|
||||
|
||||
limitValue := rateLimitValue
|
||||
if limitValue == 0 {
|
||||
limitValue = defaultBrowserExtensionApiBandwidthAbuseThreshold
|
||||
}
|
||||
rate := rate_limit.Rate{
|
||||
TimeUnit: time.Minute,
|
||||
Limit: browserExtensionApiBandwidthAbuseThreshold,
|
||||
Limit: limitValue,
|
||||
}
|
||||
|
||||
limitReached := rateLimiter.Test(context.Background(), key, rate)
|
||||
limitReached := rateLimiter.Test(c, key, rate)
|
||||
|
||||
if limitReached {
|
||||
logging.WithFields(logging.Fields{
|
||||
@ -34,7 +38,8 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter
|
||||
"uri": c.Request.URL.String(),
|
||||
"browser_extension_id": extensionId,
|
||||
"ip": c.ClientIP(),
|
||||
}).Warning("API potentially abused at Browser Extension scope")
|
||||
}).Warning("API potentially abused at Browser Extension scope, blocking")
|
||||
c.AbortWithStatus(http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/twofas/2fas-server/config"
|
||||
"github.com/twofas/2fas-server/internal/api/browser_extension/adapters"
|
||||
"github.com/twofas/2fas-server/internal/api/browser_extension/app"
|
||||
@ -125,8 +125,8 @@ func NewBrowserExtensionModule(
|
||||
func (m *BrowserExtensionModule) RegisterPublicRoutes(router *gin.Engine) {
|
||||
rateLimiter := rate_limit.New(m.Redis)
|
||||
|
||||
bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter)
|
||||
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter)
|
||||
bandwidthAuditMiddleware := apisec.BrowserExtensionBandwidthAuditMiddleware(rateLimiter, m.Config.Security.RateLimitBE)
|
||||
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP)
|
||||
|
||||
publicRouter := router.Group("/")
|
||||
publicRouter.Use(iPAbuseAuditMiddleware)
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/twofas/2fas-server/config"
|
||||
mobile "github.com/twofas/2fas-server/internal/api/mobile/domain"
|
||||
support "github.com/twofas/2fas-server/internal/api/support/domain"
|
||||
|
@ -2,7 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/twofas/2fas-server/config"
|
||||
"github.com/twofas/2fas-server/internal/api/health/ports"
|
||||
)
|
||||
|
@ -1,18 +1,19 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/rate_limit"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var mobileApiBandwidthAbuseThreshold = 100
|
||||
const defaultMobileApiBandwidthAbuseThreshold = 100
|
||||
|
||||
func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc {
|
||||
func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
deviceId := c.Param("device_id")
|
||||
extensionId := c.Param("extension_id")
|
||||
@ -25,13 +26,16 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle
|
||||
fmt.Sprintf("security.api.mobile.bandwidth.%s.%s", deviceId, extensionId),
|
||||
".",
|
||||
)
|
||||
|
||||
limitValue := rateLimitValue
|
||||
if limitValue == 0 {
|
||||
limitValue = defaultMobileApiBandwidthAbuseThreshold
|
||||
}
|
||||
rate := rate_limit.Rate{
|
||||
TimeUnit: time.Minute,
|
||||
Limit: mobileApiBandwidthAbuseThreshold,
|
||||
Limit: limitValue,
|
||||
}
|
||||
|
||||
limitReached := rateLimiter.Test(context.Background(), key, rate)
|
||||
limitReached := rateLimiter.Test(c, key, rate)
|
||||
|
||||
if limitReached {
|
||||
logging.WithFields(logging.Fields{
|
||||
@ -40,7 +44,8 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.Handle
|
||||
"device_id": deviceId,
|
||||
"browser_extension_id": extensionId,
|
||||
"ip": c.ClientIP(),
|
||||
}).Warning("API potentially abused at mobile scope")
|
||||
}).Warning("API potentially abused at mobile scope, blocking")
|
||||
c.AbortWithStatus(http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/twofas/2fas-server/config"
|
||||
browser_extension_adapters "github.com/twofas/2fas-server/internal/api/browser_extension/adapters"
|
||||
"github.com/twofas/2fas-server/internal/api/mobile/adapters"
|
||||
@ -117,8 +117,8 @@ func NewMobileModule(config config.Configuration, gorm *gorm.DB, database *sql.D
|
||||
func (m *MobileModule) RegisterPublicRoutes(router *gin.Engine) {
|
||||
rateLimiter := rate_limit.New(m.Redis)
|
||||
|
||||
bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter)
|
||||
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter)
|
||||
bandwidthMobileApiMiddleware := apisec.MobileIpAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitMobile)
|
||||
iPAbuseAuditMiddleware := security.IPAbuseAuditMiddleware(rateLimiter, m.Config.Security.RateLimitIP)
|
||||
|
||||
publicRouter := router.Group("/")
|
||||
publicRouter.Use(iPAbuseAuditMiddleware)
|
||||
|
@ -2,8 +2,11 @@ package rate_limit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis_rate/v10"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
)
|
||||
|
||||
type Rate struct {
|
||||
@ -18,33 +21,33 @@ type RateLimiter interface {
|
||||
type LimitHandler func()
|
||||
|
||||
type RedisRateLimit struct {
|
||||
Client *redis.Client
|
||||
limiter *redis_rate.Limiter
|
||||
}
|
||||
|
||||
func New(client *redis.Client) RateLimiter {
|
||||
return &RedisRateLimit{Client: client}
|
||||
return &RedisRateLimit{
|
||||
limiter: redis_rate.NewLimiter(client),
|
||||
}
|
||||
}
|
||||
|
||||
// Test returns information if limit has been reached.
|
||||
func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool {
|
||||
counter, err := r.Client.Get(context.Background(), key).Int()
|
||||
|
||||
if err == redis.Nil {
|
||||
r.Client.Set(ctx, key, 1, rate.TimeUnit)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
res, err := r.limiter.Allow(ctx, key, redis_rate.Limit{
|
||||
Rate: rate.Limit,
|
||||
Burst: rate.Limit,
|
||||
Period: rate.TimeUnit,
|
||||
})
|
||||
if err != nil {
|
||||
logging.WithFields(logging.Fields{
|
||||
"type": "security",
|
||||
}).Warnf("Could not check rate limit: %v", err)
|
||||
|
||||
// for now we return that limit has not been reached.
|
||||
return false
|
||||
}
|
||||
|
||||
if counter >= rate.Limit {
|
||||
r.Client.Del(context.Background(), key)
|
||||
|
||||
if res.Allowed <= 0 {
|
||||
// limit has been reached.
|
||||
return true
|
||||
} else {
|
||||
r.Client.Incr(context.Background(), key)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -2,7 +2,8 @@ package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -1,35 +1,41 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/twofas/2fas-server/internal/common/logging"
|
||||
"github.com/twofas/2fas-server/internal/common/rate_limit"
|
||||
"time"
|
||||
)
|
||||
|
||||
var apiBandwidthAbuseThreshold = 100
|
||||
const defaultAPIBandwidthAbuseThreshold = 100
|
||||
|
||||
func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter) gin.HandlerFunc {
|
||||
func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue int) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientIp := c.ClientIP()
|
||||
|
||||
key := fmt.Sprintf("security.api.ip_bandwidth_audit.%s", clientIp)
|
||||
|
||||
limitValue := rateLimitValue
|
||||
if limitValue == 0 {
|
||||
limitValue = defaultAPIBandwidthAbuseThreshold
|
||||
}
|
||||
rate := rate_limit.Rate{
|
||||
TimeUnit: time.Minute,
|
||||
Limit: apiBandwidthAbuseThreshold,
|
||||
Limit: limitValue,
|
||||
}
|
||||
|
||||
limitReached := rateLimiter.Test(context.Background(), key, rate)
|
||||
limitReached := rateLimiter.Test(c, key, rate)
|
||||
|
||||
if limitReached {
|
||||
logging.WithFields(logging.Fields{
|
||||
"type": "security",
|
||||
"uri": c.Request.URL.String(),
|
||||
"ip": c.ClientIP(),
|
||||
}).Warning("API potentially abused by Client IP")
|
||||
}).Warning("API potentially abused by Client IP, blocking")
|
||||
c.AbortWithStatus(http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,24 +1,87 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/twofas/2fas-server/tests"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func Test_MobileApiBandwidthAbuse(t *testing.T) {
|
||||
someId := uuid.New()
|
||||
|
||||
for i := 0; i <= 100; i++ {
|
||||
tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil)
|
||||
noOfRequest := 130
|
||||
noOfWorkers := 20
|
||||
responseCh := make(chan int, noOfRequest)
|
||||
|
||||
eg := errgroup.Group{}
|
||||
eg.SetLimit(noOfWorkers)
|
||||
for i := 0; i < noOfRequest; i++ {
|
||||
eg.Go(func() error {
|
||||
resp := tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil)
|
||||
|
||||
responseCh <- resp.StatusCode
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
require.NoError(t, eg.Wait())
|
||||
close(responseCh)
|
||||
|
||||
var got404, got429 int
|
||||
for code := range responseCh {
|
||||
switch code {
|
||||
case http.StatusNotFound:
|
||||
got404++
|
||||
case http.StatusTooManyRequests:
|
||||
got429++
|
||||
default:
|
||||
t.Fatalf("Unexpected code: %v", code)
|
||||
}
|
||||
}
|
||||
// Default rate limit is 100 per minute.
|
||||
// So we expect around 100 - 404, and around 30 - 429
|
||||
require.InDelta(t, 100, got404, 2.0)
|
||||
require.InDelta(t, 30, got429, 2.0)
|
||||
}
|
||||
|
||||
func Test_BrowserExtensionApiBandwidthAbuse(t *testing.T) {
|
||||
someId := uuid.New()
|
||||
|
||||
for i := 0; i <= 100; i++ {
|
||||
tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil)
|
||||
noOfRequest := 130
|
||||
noOfWorkers := 20
|
||||
responseCh := make(chan int, noOfRequest)
|
||||
|
||||
eg := errgroup.Group{}
|
||||
eg.SetLimit(noOfWorkers)
|
||||
for i := 0; i < noOfRequest; i++ {
|
||||
eg.Go(func() error {
|
||||
resp := tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil)
|
||||
|
||||
responseCh <- resp.StatusCode
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
require.NoError(t, eg.Wait())
|
||||
close(responseCh)
|
||||
|
||||
var got404, got429 int
|
||||
for code := range responseCh {
|
||||
switch code {
|
||||
case http.StatusNotFound:
|
||||
got404++
|
||||
case http.StatusTooManyRequests:
|
||||
got429++
|
||||
default:
|
||||
t.Fatalf("Unexpected code: %v", code)
|
||||
}
|
||||
}
|
||||
// Default rate limit is 100 per minute.
|
||||
// So we expect around 100 - 404, and around 30 - 429
|
||||
require.InDelta(t, 100, got404, 2.0)
|
||||
require.InDelta(t, 30, got429, 2.0)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user