-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathratelimit.go
112 lines (97 loc) · 2.55 KB
/
ratelimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package ratelimit
import (
	"context"
	_ "embed"
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
	xrate "golang.org/x/time/rate"
)
// pingInterval is how often waitForRedis pings Redis while the bucket is
// relying on its in-process fallback limiter.
const pingInterval = 100 * time.Millisecond
var (
//go:embed tokenscript.lua
luaScript string
tokenScript = redis.NewScript(luaScript)
)
// Bucket is a token-bucket rate limiter whose state is kept in Redis, with an
// in-process fallback limiter used while Redis is unavailable.
type Bucket struct {
	// rdb executes tokenScript and answers health-check pings. NewBucket
	// treats a nil rdb as "always use the fallback limiter".
	rdb rediser
	// key is the Redis key handed to tokenScript as KEYS[1].
	key string
	// capacity and rate are forwarded to tokenScript on every call and also
	// configure rescueLimiter.
	capacity, rate int

	// redisAlive is 1 while decisions go to Redis and 0 while rescueLimiter
	// is used. After construction it is read and written via sync/atomic.
	redisAlive uint32

	// rescueLock guards monitoring.
	rescueLock sync.Mutex
	// rescueLimiter decides TakeN when Redis is marked down or a script
	// call fails.
	rescueLimiter *xrate.Limiter
	// monitoring is true while a waitForRedis goroutine is running.
	monitoring bool
}
// rediser is the subset of the go-redis client API that Bucket needs: the
// methods redis.Script requires to run tokenScript, plus Ping for health
// checks. Go-redis clients such as *redis.Client satisfy it, and tests can
// substitute a fake.
type rediser interface {
	// Script execution, by source or by SHA1, in read-write and read-only form.
	Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
	EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
	EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
	EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd

	// Script cache management.
	ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
	ScriptLoad(ctx context.Context, script string) *redis.StringCmd

	// Liveness probe used by waitForRedis.
	Ping(ctx context.Context) *redis.StatusCmd
}
// NewBucket returns a Bucket backed by the Redis key key, configured with
// rate and capacity.
//
// The in-process fallback limiter treats rate as tokens per second with a
// burst of capacity. NOTE(review): presumably tokenscript.lua uses the same
// units; confirm against the script.
//
// If rdb is nil, the bucket never contacts Redis and every decision is made
// by the fallback limiter.
//
// NewBucket panics if rate <= 0 or capacity < 0.
func NewBucket(rdb rediser, key string, rate, capacity int) *Bucket {
	if rate <= 0 {
		panic("rate must be greater than 0")
	}
	if capacity < 0 {
		panic("capacity must be greater than or equal to 0")
	}

	alive := uint32(1)
	if rdb == nil {
		alive = 0
	}

	return &Bucket{
		rdb:        rdb,
		key:        key,
		capacity:   capacity,
		rate:       rate,
		redisAlive: alive,
		// xrate.Limit is events per second, so this is exact. Deriving it
		// via xrate.Every(time.Second/time.Duration(rate)) would truncate
		// the interval to whole nanoseconds. It would also yield an
		// unlimited (Inf) limiter once rate exceeds 1e9, because the
		// interval becomes 0.
		rescueLimiter: xrate.NewLimiter(xrate.Limit(rate), capacity),
	}
}
// Take tries to take a single token and reports whether it was granted.
// It is shorthand for TakeN(1).
func (b *Bucket) Take() bool {
	granted := b.TakeN(1)
	return granted
}
// TakeN tries to take count tokens and reports whether they were granted.
//
// While Redis is marked alive, the decision comes from tokenScript, and a
// positive script result means the tokens were granted. When Redis is marked
// down, the in-process rescueLimiter decides instead.
//
// A script error of redis.Nil, context.Canceled or context.DeadlineExceeded
// (possibly wrapped) is treated as a rejection. Any other error calls
// monitor, which marks Redis down and starts a health check, and then the
// request is decided by rescueLimiter.
func (b *Bucket) TakeN(count int) bool {
	if atomic.LoadUint32(&b.redisAlive) == 0 {
		return b.rescueLimiter.AllowN(time.Now(), count)
	}

	// NOTE(review): Unix() has one-second resolution; presumably
	// tokenscript.lua expects seconds. Confirm against the script.
	result, err := tokenScript.Run(context.Background(), b.rdb, []string{b.key},
		b.rate, b.capacity, time.Now().Unix(), count).Int()
	if err != nil {
		// errors.Is also matches sentinels wrapped by the client, which a
		// plain == comparison would miss.
		if errors.Is(err, redis.Nil) ||
			errors.Is(err, context.Canceled) ||
			errors.Is(err, context.DeadlineExceeded) {
			return false
		}
		b.monitor()
		return b.rescueLimiter.AllowN(time.Now(), count)
	}
	return result > 0
}
// monitor switches the bucket to its fallback limiter and starts a single
// background waitForRedis goroutine. If a watcher is already running, the
// call does nothing. rescueLock makes the check-and-set of monitoring atomic
// with respect to other callers.
func (b *Bucket) monitor() {
	b.rescueLock.Lock()
	defer b.rescueLock.Unlock()

	if !b.monitoring {
		b.monitoring = true
		atomic.StoreUint32(&b.redisAlive, 0)
		go b.waitForRedis()
	}
}
// waitForRedis pings Redis every pingInterval until a ping succeeds. It then
// marks Redis alive again and clears monitoring in a single critical section
// under rescueLock.
//
// Updating both flags together closes a gap in the earlier ordering, where
// redisAlive was set before monitoring was cleared. A TakeN failure in that
// gap called monitor(), which saw monitoring still true and returned. That
// left Redis marked alive with no watcher running.
//
// NOTE(review): there is no way to stop this goroutine. If Redis never
// recovers, it keeps polling for the life of the process.
func (b *Bucket) waitForRedis() {
	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for range ticker.C {
		if b.rdb.Ping(context.Background()).Err() != nil {
			continue
		}

		// monitor() takes the same lock. Any failure observed after Redis is
		// marked alive is therefore handled only once monitoring is cleared,
		// so it correctly starts a new watcher.
		b.rescueLock.Lock()
		b.monitoring = false
		atomic.StoreUint32(&b.redisAlive, 1)
		b.rescueLock.Unlock()
		return
	}
}