Skip to content

Commit 6268b15

Browse files
authored
Merge pull request #50 from stinovlas/add-cluster-schedule-source
2 parents caeccfb + 05cf89d commit 6268b15

File tree

3 files changed

+234
-3
lines changed

3 files changed

+234
-3
lines changed

taskiq_redis/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
)
66
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
77
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
8-
from taskiq_redis.schedule_source import RedisScheduleSource
8+
from taskiq_redis.schedule_source import (
9+
RedisClusterScheduleSource,
10+
RedisScheduleSource,
11+
)
912

1013
__all__ = [
1114
"RedisAsyncClusterResultBackend",
@@ -14,4 +17,5 @@
1417
"PubSubBroker",
1518
"ListQueueClusterBroker",
1619
"RedisScheduleSource",
20+
"RedisClusterScheduleSource",
1721
]

taskiq_redis/schedule_source.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, List, Optional
22

3-
from redis.asyncio import ConnectionPool, Redis
3+
from redis.asyncio import ConnectionPool, Redis, RedisCluster
44
from taskiq import ScheduleSource
55
from taskiq.abc.serializer import TaskiqSerializer
66
from taskiq.compat import model_dump, model_validate
@@ -95,3 +95,82 @@ async def post_send(self, task: ScheduledTask) -> None:
9595
async def shutdown(self) -> None:
9696
"""Shut down the schedule source."""
9797
await self.connection_pool.disconnect()
98+
99+
100+
class RedisClusterScheduleSource(ScheduleSource):
101+
"""
102+
Source of schedules for redis cluster.
103+
104+
This class allows you to store schedules in redis.
105+
Also it supports dynamic schedules.
106+
107+
:param url: url to redis cluster.
108+
:param prefix: prefix for redis schedule keys.
109+
:param buffer_size: buffer size for redis scan.
110+
This is how many keys will be fetched at once.
111+
:param max_connection_pool_size: maximum number of connections in pool.
112+
:param serializer: serializer for data.
113+
:param connection_kwargs: additional arguments for RedisCluster.
114+
"""
115+
116+
def __init__(
117+
self,
118+
url: str,
119+
prefix: str = "schedule",
120+
buffer_size: int = 50,
121+
serializer: Optional[TaskiqSerializer] = None,
122+
**connection_kwargs: Any,
123+
) -> None:
124+
self.prefix = prefix
125+
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
126+
url,
127+
**connection_kwargs,
128+
)
129+
self.buffer_size = buffer_size
130+
if serializer is None:
131+
serializer = PickleSerializer()
132+
self.serializer = serializer
133+
134+
async def delete_schedule(self, schedule_id: str) -> None:
135+
"""Remove schedule by id."""
136+
await self.redis.delete(f"{self.prefix}:{schedule_id}") # type: ignore[attr-defined]
137+
138+
async def add_schedule(self, schedule: ScheduledTask) -> None:
139+
"""
140+
Add schedule to redis.
141+
142+
:param schedule: schedule to add.
143+
:param schedule_id: schedule id.
144+
"""
145+
await self.redis.set( # type: ignore[attr-defined]
146+
f"{self.prefix}:{schedule.schedule_id}",
147+
self.serializer.dumpb(model_dump(schedule)),
148+
)
149+
150+
async def get_schedules(self) -> List[ScheduledTask]:
151+
"""
152+
Get all schedules from redis.
153+
154+
This method is used by scheduler to get all schedules.
155+
156+
:return: list of schedules.
157+
"""
158+
schedules = []
159+
buffer = []
160+
async for key in self.redis.scan_iter(f"{self.prefix}:*"): # type: ignore[attr-defined]
161+
buffer.append(key)
162+
if len(buffer) >= self.buffer_size:
163+
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
164+
buffer = []
165+
if buffer:
166+
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
167+
return [
168+
model_validate(ScheduledTask, self.serializer.loadb(schedule))
169+
for schedule in schedules
170+
if schedule
171+
]
172+
173+
async def post_send(self, task: ScheduledTask) -> None:
174+
"""Delete a task after it's completed."""
175+
if task.time is not None:
176+
await self.delete_schedule(task.schedule_id)

tests/test_schedule_source.py

+149-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import datetime as dt
12
import uuid
23

34
import pytest
45
from taskiq import ScheduledTask
56

6-
from taskiq_redis import RedisScheduleSource
7+
from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource
78

89

910
@pytest.mark.anyio
@@ -56,6 +57,153 @@ async def test_post_run_cron(redis_url: str) -> None:
5657
cron="* * * * *",
5758
)
5859
await source.add_schedule(schedule)
60+
assert await source.get_schedules() == [schedule]
61+
await source.post_send(schedule)
62+
assert await source.get_schedules() == [schedule]
63+
await source.shutdown()
64+
65+
66+
@pytest.mark.anyio
67+
async def test_post_run_time(redis_url: str) -> None:
68+
prefix = uuid.uuid4().hex
69+
source = RedisScheduleSource(redis_url, prefix=prefix)
70+
schedule = ScheduledTask(
71+
task_name="test_task",
72+
labels={},
73+
args=[],
74+
kwargs={},
75+
time=dt.datetime(2000, 1, 1),
76+
)
77+
await source.add_schedule(schedule)
78+
assert await source.get_schedules() == [schedule]
79+
await source.post_send(schedule)
80+
assert await source.get_schedules() == []
81+
await source.shutdown()
82+
83+
84+
@pytest.mark.anyio
85+
async def test_buffer(redis_url: str) -> None:
86+
prefix = uuid.uuid4().hex
87+
source = RedisScheduleSource(redis_url, prefix=prefix, buffer_size=1)
88+
schedule1 = ScheduledTask(
89+
task_name="test_task1",
90+
labels={},
91+
args=[],
92+
kwargs={},
93+
cron="* * * * *",
94+
)
95+
schedule2 = ScheduledTask(
96+
task_name="test_task2",
97+
labels={},
98+
args=[],
99+
kwargs={},
100+
cron="* * * * *",
101+
)
102+
await source.add_schedule(schedule1)
103+
await source.add_schedule(schedule2)
104+
schedules = await source.get_schedules()
105+
assert len(schedules) == 2
106+
assert schedule1 in schedules
107+
assert schedule2 in schedules
108+
await source.shutdown()
109+
110+
111+
@pytest.mark.anyio
112+
async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
113+
prefix = uuid.uuid4().hex
114+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
115+
schedule = ScheduledTask(
116+
task_name="test_task",
117+
labels={},
118+
args=[],
119+
kwargs={},
120+
cron="* * * * *",
121+
)
122+
await source.add_schedule(schedule)
123+
schedules = await source.get_schedules()
124+
assert schedules == [schedule]
125+
await source.shutdown()
126+
127+
128+
@pytest.mark.anyio
129+
async def test_cluster_delete_schedule(redis_cluster_url: str) -> None:
130+
prefix = uuid.uuid4().hex
131+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
132+
schedule = ScheduledTask(
133+
task_name="test_task",
134+
labels={},
135+
args=[],
136+
kwargs={},
137+
cron="* * * * *",
138+
)
139+
await source.add_schedule(schedule)
59140
schedules = await source.get_schedules()
60141
assert schedules == [schedule]
142+
await source.delete_schedule(schedule.schedule_id)
143+
schedules = await source.get_schedules()
144+
# Schedules are empty.
145+
assert not schedules
146+
await source.shutdown()
147+
148+
149+
@pytest.mark.anyio
150+
async def test_cluster_post_run_cron(redis_cluster_url: str) -> None:
151+
prefix = uuid.uuid4().hex
152+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
153+
schedule = ScheduledTask(
154+
task_name="test_task",
155+
labels={},
156+
args=[],
157+
kwargs={},
158+
cron="* * * * *",
159+
)
160+
await source.add_schedule(schedule)
161+
assert await source.get_schedules() == [schedule]
162+
await source.post_send(schedule)
163+
assert await source.get_schedules() == [schedule]
164+
await source.shutdown()
165+
166+
167+
@pytest.mark.anyio
168+
async def test_cluster_post_run_time(redis_cluster_url: str) -> None:
169+
prefix = uuid.uuid4().hex
170+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
171+
schedule = ScheduledTask(
172+
task_name="test_task",
173+
labels={},
174+
args=[],
175+
kwargs={},
176+
time=dt.datetime(2000, 1, 1),
177+
)
178+
await source.add_schedule(schedule)
179+
assert await source.get_schedules() == [schedule]
180+
await source.post_send(schedule)
181+
assert await source.get_schedules() == []
182+
await source.shutdown()
183+
184+
185+
@pytest.mark.anyio
186+
async def test_cluster_buffer(redis_cluster_url: str) -> None:
187+
prefix = uuid.uuid4().hex
188+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix, buffer_size=1)
189+
schedule1 = ScheduledTask(
190+
task_name="test_task1",
191+
labels={},
192+
args=[],
193+
kwargs={},
194+
cron="* * * * *",
195+
)
196+
schedule2 = ScheduledTask(
197+
task_name="test_task2",
198+
labels={},
199+
args=[],
200+
kwargs={},
201+
cron="* * * * *",
202+
)
203+
await source.add_schedule(schedule1)
204+
await source.add_schedule(schedule2)
205+
schedules = await source.get_schedules()
206+
assert len(schedules) == 2
207+
assert schedule1 in schedules
208+
assert schedule2 in schedules
61209
await source.shutdown()

0 commit comments

Comments
 (0)