// Source: 2fas-server/internal/common/rate_limit/redis_rate_limit.go
package rate_limit
import (
"context"
"time"
2024-01-02 09:48:34 +01:00
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
"github.com/twofas/2fas-server/internal/common/logging"
2022-12-31 10:22:38 +01:00
)
type (
	// Rate describes a quota: at most Limit events per TimeUnit.
	// RedisRateLimit.Test also uses Limit as the burst size.
	Rate struct {
		TimeUnit time.Duration
		Limit    int
	}

	// RateLimiter checks whether the action identified by a key has exceeded
	// its quota.
	RateLimiter interface {
		// Test reports whether the limit for key under rate has been reached.
		// A return value of true means the caller should reject the action.
		Test(ctx context.Context, key string, rate Rate) bool
	}

	// LimitHandler is a parameterless callback. It is not referenced in this
	// file; presumably it is invoked when a limit is hit (confirm with callers).
	LimitHandler func()
)
type RedisRateLimit struct {
2024-01-02 09:48:34 +01:00
limiter *redis_rate.Limiter
2022-12-31 10:22:38 +01:00
}
func New(client *redis.Client) RateLimiter {
2024-01-02 09:48:34 +01:00
return &RedisRateLimit{
limiter: redis_rate.NewLimiter(client),
}
2022-12-31 10:22:38 +01:00
}
2024-01-02 09:48:34 +01:00
// Test returns information if limit has been reached.
2022-12-31 10:22:38 +01:00
func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool {
2024-01-02 09:48:34 +01:00
res, err := r.limiter.Allow(ctx, key, redis_rate.Limit{
Rate: rate.Limit,
Burst: rate.Limit,
Period: rate.TimeUnit,
})
2022-12-31 10:22:38 +01:00
if err != nil {
2024-01-02 09:48:34 +01:00
logging.WithFields(logging.Fields{
"type": "security",
}).Warnf("Could not check rate limit: %v", err)
// for now we return that limit has not been reached.
2022-12-31 10:22:38 +01:00
return false
}
2024-01-02 09:48:34 +01:00
if res.Allowed <= 0 {
// limit has been reached.
2022-12-31 10:22:38 +01:00
return true
}
return false
}