2022-12-31 10:22:38 +01:00
|
|
|
package rate_limit
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"time"
|
2024-01-02 09:48:34 +01:00
|
|
|
|
|
|
|
"github.com/go-redis/redis_rate/v10"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
"github.com/twofas/2fas-server/internal/common/logging"
|
2022-12-31 10:22:38 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// Rate describes a rate limit: at most Limit events per TimeUnit.
//
// RedisRateLimit.Test uses Limit as both the rate and the burst, and
// TimeUnit as the period, of the underlying limiter.
type Rate struct {
	// TimeUnit is the length of the window the limit applies to.
	TimeUnit time.Duration
	// Limit is the number of events permitted within TimeUnit.
	Limit int
}
|
|
|
|
|
|
|
|
type RateLimiter interface {
|
|
|
|
Test(ctx context.Context, key string, rate Rate) bool
|
|
|
|
}
|
|
|
|
|
|
|
|
// LimitHandler is a callback that takes no arguments and returns nothing.
// NOTE(review): not referenced in this file; presumably invoked when a
// rate limit is hit — confirm against callers.
type LimitHandler func()
|
|
|
|
|
|
|
|
type RedisRateLimit struct {
|
2024-01-02 09:48:34 +01:00
|
|
|
limiter *redis_rate.Limiter
|
2022-12-31 10:22:38 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func New(client *redis.Client) RateLimiter {
|
2024-01-02 09:48:34 +01:00
|
|
|
return &RedisRateLimit{
|
|
|
|
limiter: redis_rate.NewLimiter(client),
|
|
|
|
}
|
2022-12-31 10:22:38 +01:00
|
|
|
}
|
|
|
|
|
2024-01-02 09:48:34 +01:00
|
|
|
// Test returns information if limit has been reached.
|
2022-12-31 10:22:38 +01:00
|
|
|
func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool {
|
2024-01-02 09:48:34 +01:00
|
|
|
res, err := r.limiter.Allow(ctx, key, redis_rate.Limit{
|
|
|
|
Rate: rate.Limit,
|
|
|
|
Burst: rate.Limit,
|
|
|
|
Period: rate.TimeUnit,
|
|
|
|
})
|
2022-12-31 10:22:38 +01:00
|
|
|
if err != nil {
|
2024-03-16 19:05:21 +01:00
|
|
|
logging.FromContext(ctx).WithFields(logging.Fields{
|
2024-01-02 09:48:34 +01:00
|
|
|
"type": "security",
|
|
|
|
}).Warnf("Could not check rate limit: %v", err)
|
|
|
|
|
|
|
|
// for now we return that limit has not been reached.
|
2022-12-31 10:22:38 +01:00
|
|
|
return false
|
|
|
|
}
|
2024-01-02 09:48:34 +01:00
|
|
|
if res.Allowed <= 0 {
|
|
|
|
// limit has been reached.
|
2022-12-31 10:22:38 +01:00
|
|
|
return true
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
}
|