Skip to content

Commit

Permalink
asd
Browse files Browse the repository at this point in the history
  • Loading branch information
dbarrosop committed Aug 13, 2024
1 parent 3c16c83 commit 857d6a4
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 35 deletions.
76 changes: 54 additions & 22 deletions go/middleware/rate_limit.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"math"
"net/http"
"slices"
"strings"
Expand All @@ -9,26 +10,46 @@ import (
"github.com/gin-gonic/gin"
)

const (
replicas = 1
)

type BurstBucket struct {
recoverRateMs float64
maxBurst int64
currentBurst int64
lastEntry time.Time
replicas int64 // for future use
interval time.Duration
maxBurst int64
currentBurst int64
lastEntry time.Time
}

func NewBurstBucket(maxBurst int64, interval time.Duration) *BurstBucket {
func NewBurstBucket(maxBurst int64, interval time.Duration, replicas int64) *BurstBucket {
return &BurstBucket{
recoverRateMs: float64(interval.Milliseconds()) / float64(maxBurst),
maxBurst: maxBurst,
currentBurst: 0,
lastEntry: time.Time{},
replicas: replicas,
interval: interval,
maxBurst: maxBurst,
currentBurst: 0,
lastEntry: time.Time{},
}
}

func (b *BurstBucket) Clean() {
func (b *BurstBucket) SetReplicas(replicas int64) {
b.currentBurst = int64(
math.Floor(float64(b.currentBurst) * (float64(b.replicas) / float64(replicas))))
b.replicas = replicas
}

func (b *BurstBucket) MaxBurst() int64 {
return b.maxBurst / b.replicas
}

func (b *BurstBucket) Recoverable() int64 {
elapsedMs := float64(time.Since(b.lastEntry).Milliseconds())
recoveryRateMs := float64(b.interval.Milliseconds()) / float64(b.MaxBurst())
return int64(elapsedMs / recoveryRateMs)
}

currentBurst := b.currentBurst - int64(elapsedMs/b.recoverRateMs)
func (b *BurstBucket) Clean() {
currentBurst := b.currentBurst - b.Recoverable()

if currentBurst < 0 {
b.currentBurst = 0
Expand All @@ -46,7 +67,7 @@ func (b *BurstBucket) Add() bool {
b.Clean()

// if we are at the limit, we can't add more
if b.currentBurst >= b.maxBurst {
if b.currentBurst >= b.MaxBurst() {
return false
}

Expand All @@ -59,17 +80,27 @@ func (b *BurstBucket) Add() bool {
type RateLimiter struct {
maxBurst int64
recoveryRate time.Duration
replicas int64
buckets map[string]*BurstBucket
}

func NewRateLimiter(maxBurst int64, recoveryRate time.Duration) *RateLimiter {
func NewRateLimiter(maxBurst int64, recoveryRate time.Duration, replicas int64) *RateLimiter {
return &RateLimiter{
replicas: replicas,
maxBurst: maxBurst,
recoveryRate: recoveryRate,
buckets: make(map[string]*BurstBucket),
}
}

func (r *RateLimiter) SetReplicas(replicas int64) {
r.replicas = replicas

for _, bucket := range r.buckets {
bucket.SetReplicas(replicas)
}
}

func (r *RateLimiter) Clean() {
for key, bucket := range r.buckets {
if bucket.CurrentBurst() == 0 {
Expand All @@ -81,7 +112,7 @@ func (r *RateLimiter) Clean() {
func (r *RateLimiter) Add(key string) bool {
bucket, ok := r.buckets[key]
if !ok {
bucket = NewBurstBucket(r.maxBurst, r.recoveryRate)
bucket = NewBurstBucket(r.maxBurst, r.recoveryRate, r.replicas)
r.buckets[key] = bucket
}

Expand Down Expand Up @@ -117,7 +148,8 @@ func sendsSMS(path string) bool {
// endpnits that can be brute forced.
func bruteForceProtected(path string) bool {
return strings.HasPrefix(path, "/signin") ||
strings.HasSuffix(path, "/verify")
strings.HasSuffix(path, "/verify") ||
strings.HasSuffix(path, "/otp")
}

// #s.
Expand Down Expand Up @@ -163,22 +195,22 @@ func RateLimit( //nolint:cyclop,funlen,gocognit
) gin.HandlerFunc {
lastClean := time.Now()

perUserRL := NewRateLimiter(globalLimit, globalInterval)
perUserRL := NewRateLimiter(globalLimit, globalInterval, replicas)

var globalEmailRL *BurstBucket
var perUserEmailRL *RateLimiter
if emailIsGlobal {
globalEmailRL = NewBurstBucket(emailLimit, emailInterval)
globalEmailRL = NewBurstBucket(emailLimit, emailInterval, replicas)
} else {
perUserEmailRL = NewRateLimiter(emailLimit, emailInterval)
perUserEmailRL = NewRateLimiter(emailLimit, emailInterval, replicas)
}

globalSMSRL := NewRateLimiter(smsLimit, smsInterval)
perUserBruteForceRL := NewRateLimiter(bruteForceLimit, bruteForceInterval)
perUser#sRL := NewBurstBucket(#sLimit, #sInterval)
globalSMSRL := NewRateLimiter(smsLimit, smsInterval, replicas)
perUserBruteForceRL := NewRateLimiter(bruteForceLimit, bruteForceInterval, replicas)
perUser#sRL := NewBurstBucket(#sLimit, #sInterval, replicas)

return func(ctx *gin.Context) {
if time.Since(lastClean) > 5*time.Minute {
if time.Since(lastClean) > 1*time.Minute {
perUserRL.Clean()
if globalEmailRL != nil {
globalEmailRL.Clean()
Expand Down
Loading

0 comments on commit 857d6a4

Please # to comment.