From 857d6a46634b58a8d2161d3f69f5b8213de83529 Mon Sep 17 00:00:00 2001 From: David Barroso Date: Tue, 13 Aug 2024 10:59:57 +0200 Subject: [PATCH] asd --- go/middleware/rate_limit.go | 76 ++++++++++++------ go/middleware/rate_limit_test.go | 127 +++++++++++++++++++++++++++---- 2 files changed, 168 insertions(+), 35 deletions(-) diff --git a/go/middleware/rate_limit.go b/go/middleware/rate_limit.go index 4305ad83..ffae5e70 100644 --- a/go/middleware/rate_limit.go +++ b/go/middleware/rate_limit.go @@ -1,6 +1,7 @@ package middleware import ( + "math" "net/http" "slices" "strings" @@ -9,26 +10,46 @@ import ( "github.com/gin-gonic/gin" ) +const ( + replicas = 1 +) + type BurstBucket struct { - recoverRateMs float64 - maxBurst int64 - currentBurst int64 - lastEntry time.Time + replicas int64 // for future use + interval time.Duration + maxBurst int64 + currentBurst int64 + lastEntry time.Time } -func NewBurstBucket(maxBurst int64, interval time.Duration) *BurstBucket { +func NewBurstBucket(maxBurst int64, interval time.Duration, replicas int64) *BurstBucket { return &BurstBucket{ - recoverRateMs: float64(interval.Milliseconds()) / float64(maxBurst), - maxBurst: maxBurst, - currentBurst: 0, - lastEntry: time.Time{}, + replicas: replicas, + interval: interval, + maxBurst: maxBurst, + currentBurst: 0, + lastEntry: time.Time{}, } } -func (b *BurstBucket) Clean() { +func (b *BurstBucket) SetReplicas(replicas int64) { + b.currentBurst = int64( + math.Floor(float64(b.currentBurst) * (float64(b.replicas) / float64(replicas)))) + b.replicas = replicas +} + +func (b *BurstBucket) MaxBurst() int64 { + return b.maxBurst / b.replicas +} + +func (b *BurstBucket) Recoverable() int64 { elapsedMs := float64(time.Since(b.lastEntry).Milliseconds()) + recoveryRateMs := float64(b.interval.Milliseconds()) / float64(b.MaxBurst()) + return int64(elapsedMs / recoveryRateMs) +} - currentBurst := b.currentBurst - int64(elapsedMs/b.recoverRateMs) +func (b *BurstBucket) Clean() { + currentBurst := b.currentBurst - b.Recoverable() if currentBurst < 0 { b.currentBurst = 0 @@ -46,7 +67,7 @@ func (b *BurstBucket) Add() bool { b.Clean() // if we are at the limit, we can't add more - if b.currentBurst >= b.maxBurst { + if b.currentBurst >= b.MaxBurst() { return false } @@ -59,17 +80,27 @@ func (b *BurstBucket) Add() bool { type RateLimiter struct { maxBurst int64 recoveryRate time.Duration + replicas int64 buckets map[string]*BurstBucket } -func NewRateLimiter(maxBurst int64, recoveryRate time.Duration) *RateLimiter { +func NewRateLimiter(maxBurst int64, recoveryRate time.Duration, replicas int64) *RateLimiter { return &RateLimiter{ + replicas: replicas, maxBurst: maxBurst, recoveryRate: recoveryRate, buckets: make(map[string]*BurstBucket), } } +func (r *RateLimiter) SetReplicas(replicas int64) { + r.replicas = replicas + + for _, bucket := range r.buckets { + bucket.SetReplicas(replicas) + } +} + func (r *RateLimiter) Clean() { for key, bucket := range r.buckets { if bucket.CurrentBurst() == 0 { @@ -81,7 +112,7 @@ func (r *RateLimiter) Clean() { func (r *RateLimiter) Add(key string) bool { bucket, ok := r.buckets[key] if !ok { - bucket = NewBurstBucket(r.maxBurst, r.recoveryRate) + bucket = NewBurstBucket(r.maxBurst, r.recoveryRate, r.replicas) r.buckets[key] = bucket } @@ -117,7 +148,8 @@ func sendsSMS(path string) bool { // endpnits that can be brute forced. func bruteForceProtected(path string) bool { return strings.HasPrefix(path, "/signin") || - strings.HasSuffix(path, "/verify") + strings.HasSuffix(path, "/verify") || + strings.HasSuffix(path, "/otp") } // signups. @@ -163,22 +195,22 @@ func RateLimit( //nolint:cyclop,funlen,gocognit ) gin.HandlerFunc { lastClean := time.Now() - perUserRL := NewRateLimiter(globalLimit, globalInterval) + perUserRL := NewRateLimiter(globalLimit, globalInterval, replicas) var globalEmailRL *BurstBucket var perUserEmailRL *RateLimiter if emailIsGlobal { - globalEmailRL = NewBurstBucket(emailLimit, emailInterval) + globalEmailRL = NewBurstBucket(emailLimit, emailInterval, replicas) } else { - perUserEmailRL = NewRateLimiter(emailLimit, emailInterval) + perUserEmailRL = NewRateLimiter(emailLimit, emailInterval, replicas) } - globalSMSRL := NewRateLimiter(smsLimit, smsInterval) - perUserBruteForceRL := NewRateLimiter(bruteForceLimit, bruteForceInterval) - perUserSignupsRL := NewBurstBucket(signupsLimit, signupsInterval) + globalSMSRL := NewRateLimiter(smsLimit, smsInterval, replicas) + perUserBruteForceRL := NewRateLimiter(bruteForceLimit, bruteForceInterval, replicas) + perUserSignupsRL := NewBurstBucket(signupsLimit, signupsInterval, replicas) return func(ctx *gin.Context) { - if time.Since(lastClean) > 5*time.Minute { + if time.Since(lastClean) > 1*time.Minute { perUserRL.Clean() if globalEmailRL != nil { globalEmailRL.Clean() diff --git a/go/middleware/rate_limit_test.go b/go/middleware/rate_limit_test.go index b37aa75d..9a626371 100644 --- a/go/middleware/rate_limit_test.go +++ b/go/middleware/rate_limit_test.go @@ -3,72 +3,158 @@ package middleware //nolint:testpackage import ( "testing" "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) func TestBurstBucket(t *testing.T) { t.Parallel() cases := []struct { - name string - bucket func() *BurstBucket - added bool + name string + bucket func() *BurstBucket + added bool + expectedBucket *BurstBucket }{ { name: "empty", bucket: func() *BurstBucket { - return NewBurstBucket(10, time.Second) + return NewBurstBucket(10, time.Second, 1) }, added: true, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: time.Second, + maxBurst: 10, + currentBurst: 1, + lastEntry: time.Time{}, + }, }, { name: "full", bucket: func() *BurstBucket { - bucket := NewBurstBucket(10, time.Second) + bucket := NewBurstBucket(10, time.Second, 1) bucket.currentBurst = 10 bucket.lastEntry = time.Now() return bucket }, added: false, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: time.Second, + maxBurst: 10, + currentBurst: 10, + lastEntry: time.Time{}, + }, }, { name: "5/s - after 198ms", bucket: func() *BurstBucket { - bucket := NewBurstBucket(5, time.Second) + bucket := NewBurstBucket(5, time.Second, 1) bucket.lastEntry = time.Now().Add(-198 * time.Millisecond) bucket.currentBurst = 5 return bucket }, added: false, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: time.Second, + maxBurst: 5, + currentBurst: 5, + lastEntry: time.Time{}, + }, }, { name: "5/s - after 200ms", bucket: func() *BurstBucket { - bucket := NewBurstBucket(5, time.Second) + bucket := NewBurstBucket(5, time.Second, 1) bucket.lastEntry = time.Now().Add(-200 * time.Millisecond) bucket.currentBurst = 5 return bucket }, added: true, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: time.Second, + maxBurst: 5, + currentBurst: 5, + lastEntry: time.Time{}, + }, }, { name: "60/5h - after 4.9m", bucket: func() *BurstBucket { - bucket := NewBurstBucket(60, 5*time.Hour) + bucket := NewBurstBucket(60, 5*time.Hour, 1) bucket.lastEntry = time.Now().Add(-4*time.Minute - 59*time.Second) bucket.currentBurst = 60 return bucket }, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: 5 * time.Hour, + maxBurst: 60, + currentBurst: 60, + lastEntry: time.Time{}, + }, added: false, }, { name: "60/5h - after 5m", bucket: func() *BurstBucket { - bucket := NewBurstBucket(60, 5*time.Hour) + bucket := NewBurstBucket(60, 5*time.Hour, 1) bucket.lastEntry = time.Now().Add(-5 * time.Minute) bucket.currentBurst = 60 return bucket }, added: true, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: 5 * time.Hour, + maxBurst: 60, + currentBurst: 60, + lastEntry: time.Time{}, + }, + }, + { + name: "full from 1 to 2 replicas", + bucket: func() *BurstBucket { + bucket := NewBurstBucket(10, time.Second, 1) + bucket.currentBurst = 10 + bucket.lastEntry = time.Now() + + bucket.SetReplicas(2) + + return bucket + }, + added: false, + expectedBucket: &BurstBucket{ + replicas: 2, + interval: time.Second, + maxBurst: 10, + currentBurst: 5, + lastEntry: time.Time{}, + }, + }, + { + name: "full from 2 to 1 replicas", + bucket: func() *BurstBucket { + bucket := NewBurstBucket(10, time.Second, 2) + bucket.currentBurst = 5 + bucket.lastEntry = time.Now() + + bucket.SetReplicas(1) + + return bucket + }, + added: false, + expectedBucket: &BurstBucket{ + replicas: 1, + interval: time.Second, + maxBurst: 10, + currentBurst: 10, + lastEntry: time.Time{}, + }, }, } @@ -81,6 +167,15 @@ func TestBurstBucket(t *testing.T) { if got != tc.added { t.Errorf("Add() = %v; want %v", got, tc.added) } + + //nolint:exhaustruct + opts := []cmp.Option{ + cmp.AllowUnexported(BurstBucket{}), + cmpopts.IgnoreFields(BurstBucket{}, "lastEntry"), + } + if diff := cmp.Diff(tc.expectedBucket, bucket, opts...); diff != "" { + t.Errorf("Add() mismatch (-want +got):\n%s", diff) + } }) } } @@ -97,6 +192,7 @@ func TestRateLimiterAdd(t *testing.T) { name: "empty", rateLimiter: func() RateLimiter { return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: make(map[string]*BurstBucket), @@ -107,11 +203,12 @@ func TestRateLimiterAdd(t *testing.T) { { name: "full", rateLimiter: func() RateLimiter { - bucket := NewBurstBucket(10, time.Minute) + bucket := NewBurstBucket(10, time.Minute, 1) bucket.lastEntry = time.Now() bucket.currentBurst = 10 return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: map[string]*BurstBucket{ @@ -124,11 +221,12 @@ func TestRateLimiterAdd(t *testing.T) { { name: "recovered", rateLimiter: func() RateLimiter { - bucket := NewBurstBucket(10, time.Minute) + bucket := NewBurstBucket(10, time.Minute, 1) bucket.lastEntry = time.Now().Add(-1 * time.Minute) bucket.currentBurst = 10 return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: map[string]*BurstBucket{ @@ -163,6 +261,7 @@ func TestRateLimiterClean(t *testing.T) { //nolint:tparallel,paralleltest name: "empty", rateLimiter: func() RateLimiter { return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: make(map[string]*BurstBucket), @@ -173,11 +272,12 @@ func TestRateLimiterClean(t *testing.T) { //nolint:tparallel,paralleltest { name: "with one", rateLimiter: func() RateLimiter { - bucket := NewBurstBucket(10, time.Minute) + bucket := NewBurstBucket(10, time.Minute, 1) bucket.lastEntry = time.Now() bucket.currentBurst = 10 return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: map[string]*BurstBucket{ @@ -190,11 +290,12 @@ func TestRateLimiterClean(t *testing.T) { //nolint:tparallel,paralleltest { name: "with one expired", rateLimiter: func() RateLimiter { - bucket := NewBurstBucket(10, time.Minute) + bucket := NewBurstBucket(10, time.Minute, 1) bucket.lastEntry = time.Now().Add(-1 * time.Minute) bucket.currentBurst = 10 return RateLimiter{ + replicas: 1, recoveryRate: time.Minute, maxBurst: 10, buckets: map[string]*BurstBucket{