-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrate_limiter.py
81 lines (67 loc) · 2.39 KB
/
rate_limiter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from backend.utils import setup_logging
from backend.redis_wrapper import RedisHashTable, RedisQueue
from redis import Redis
from threading import Thread
from multiprocessing import Process
import time
import logging
class RateLimiter:
    """Rate limiter backed by one Redis queue per named limiter.

    Each limiter registered with :meth:`add_limiter` gets a ``RedisQueue`` and a
    daemon thread (started by :meth:`init_limiters`) that enqueues a placeholder
    token every ``rate`` seconds, as long as the queue holds fewer than
    ``lim_max_token`` tokens. Callers consume tokens with :meth:`have_token`;
    a ``False`` result means no token was available.
    """

    def __init__(self, redis_handle: "Redis"):
        """Create an empty limiter registry.

        Args:
            redis_handle: Redis client passed to every ``RedisQueue`` created
                by :meth:`add_limiter`.
        """
        # NOTE(review): setup_logging is defined elsewhere; it is assumed to
        # populate ``self.logger`` (have_token and drip_token rely on it).
        self.logger = None
        setup_logging(self, logging.INFO)
        self.redis_handle: Redis = redis_handle
        # limiter name -> queue holding that limiter's available tokens
        self.rate_limiter_queues: dict[str, RedisQueue] = dict()
        # dripper threads, one per add_limiter() call; started by init_limiters()
        self.threads: list[Thread] = list()

    def add_limiter(self, limiter_name: str, rate: float, lim_max_token: int = 1200):
        """Register a named limiter and create (but do not start) its dripper thread.

        Args:
            limiter_name: Key under which the limiter is stored; also passed to
                ``RedisQueue`` (presumably as the Redis key — verify).
            rate: Seconds to wait between drip attempts (passed to ``time.sleep``).
            lim_max_token: No token is added while the queue already holds this
                many tokens or more. Defaults to 1200.
        """
        # NOTE(review): adding the same name twice replaces the stored queue but
        # keeps the earlier thread, so both threads would drip into the queue
        # registered under that name (effectively doubling the refill rate).
        _queue = RedisQueue(self.redis_handle, limiter_name)
        self.rate_limiter_queues[limiter_name] = _queue
        t = Thread(
            target=self.drip_token,
            args=(
                limiter_name,
                rate,
                lim_max_token,
            ),
            daemon=True,
        )
        self.threads.append(t)

    def init_limiters(self):
        """Start every registered dripper thread that has not been started yet.

        Safe to call more than once. Already-started threads are skipped, so
        limiters added after an earlier call get started too. A repeated call
        does not fail with ``RuntimeError`` from ``Thread.start``.
        """
        for t in self.threads:
            # Thread.ident is None until the thread has been started.
            if t.ident is None:
                t.start()

    def drip_token(self, limiter_name: str, rate: float, lim_max_token: int = 1200):
        """Refill a limiter's queue forever; this is the body of its dripper thread.

        Every ``rate`` seconds one placeholder token is enqueued, unless the
        queue already holds ``lim_max_token`` or more tokens.

        Args:
            limiter_name: Name of a limiter registered with :meth:`add_limiter`.
            rate: Seconds to sleep between drip attempts.
            lim_max_token: Queue length at or above which no token is added.
                Defaults to 1200.

        Raises:
            KeyError: if ``limiter_name`` was never registered.
        """
        api_limiter_queue: RedisQueue = self.get_limiter_queue(limiter_name)
        # Token content is irrelevant: have_token only checks and pops the count.
        api_token = "dummy"
        while True:
            try:
                if api_limiter_queue.length() < lim_max_token:
                    api_limiter_queue.enqueue(api_token)
            except Exception:
                # This loop is the thread's top level. Letting an error (e.g. a
                # transient Redis failure) escape would kill the thread, and the
                # limiter would never refill again. Log it and keep dripping.
                self.logger.exception(
                    "Failed to drip token for limiter %r", limiter_name
                )
            time.sleep(rate)

    def have_token(self, limiter_name: str) -> bool:
        """Try to consume one token from the named limiter.

        Args:
            limiter_name: Name of a limiter registered with :meth:`add_limiter`.

        Returns:
            True if the queue was non-empty (one token is then dequeued),
            False if it was empty.

        Raises:
            KeyError: if ``limiter_name`` was never registered.
        """
        api_limiter_queue: RedisQueue = self.get_limiter_queue(limiter_name)
        qlen = api_limiter_queue.length()
        self.logger.debug(qlen)
        # NOTE(review): length() and dequeue() are separate calls. Another
        # consumer sharing the same Redis queue could take the last token in
        # between. What dequeue() does on an empty queue depends on RedisQueue.
        if qlen > 0:
            api_limiter_queue.dequeue()
            return True
        return False

    def get_limiter_queue(self, limiter_name: str) -> "RedisQueue":
        """Return the queue registered under ``limiter_name``.

        Raises:
            KeyError: if no limiter with that name was added.
        """
        return self.rate_limiter_queues[limiter_name]