Skip to content

Commit 4074e81

Browse files
committed
Add cluster schedule source
1 parent caeccfb commit 4074e81

File tree

3 files changed

+141
-3
lines changed

3 files changed

+141
-3
lines changed

Diff for: taskiq_redis/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
)
66
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
77
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
8-
from taskiq_redis.schedule_source import RedisScheduleSource
8+
from taskiq_redis.schedule_source import (
9+
RedisClusterScheduleSource,
10+
RedisScheduleSource,
11+
)
912

1013
__all__ = [
1114
"RedisAsyncClusterResultBackend",
@@ -14,4 +17,5 @@
1417
"PubSubBroker",
1518
"ListQueueClusterBroker",
1619
"RedisScheduleSource",
20+
"RedisClusterScheduleSource",
1721
]

Diff for: taskiq_redis/schedule_source.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, List, Optional
22

3-
from redis.asyncio import ConnectionPool, Redis
3+
from redis.asyncio import ConnectionPool, Redis, RedisCluster
44
from taskiq import ScheduleSource
55
from taskiq.abc.serializer import TaskiqSerializer
66
from taskiq.compat import model_dump, model_validate
@@ -95,3 +95,82 @@ async def post_send(self, task: ScheduledTask) -> None:
9595
async def shutdown(self) -> None:
9696
"""Shut down the schedule source."""
9797
await self.connection_pool.disconnect()
98+
99+
100+
class RedisClusterScheduleSource(ScheduleSource):
101+
"""
102+
Source of schedules for redis cluster.
103+
104+
This class allows you to store schedules in redis.
105+
Also it supports dynamic schedules.
106+
107+
:param url: url to redis cluster.
108+
:param prefix: prefix for redis schedule keys.
109+
:param buffer_size: buffer size for redis scan.
110+
This is how many keys will be fetched at once.
111+
:param max_connection_pool_size: maximum number of connections in pool.
112+
:param serializer: serializer for data.
113+
:param connection_kwargs: additional arguments for RedisCluster.
114+
"""
115+
116+
def __init__(
117+
self,
118+
url: str,
119+
prefix: str = "schedule",
120+
buffer_size: int = 50,
121+
serializer: Optional[TaskiqSerializer] = None,
122+
**connection_kwargs: Any,
123+
) -> None:
124+
self.prefix = prefix
125+
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
126+
url,
127+
**connection_kwargs,
128+
)
129+
self.buffer_size = buffer_size
130+
if serializer is None:
131+
serializer = PickleSerializer()
132+
self.serializer = serializer
133+
134+
async def delete_schedule(self, schedule_id: str) -> None:
135+
"""Remove schedule by id."""
136+
await self.redis.delete(f"{self.prefix}:{schedule_id}") # type: ignore[attr-defined]
137+
138+
async def add_schedule(self, schedule: ScheduledTask) -> None:
139+
"""
140+
Add schedule to redis.
141+
142+
:param schedule: schedule to add.
143+
:param schedule_id: schedule id.
144+
"""
145+
await self.redis.set( # type: ignore[attr-defined]
146+
f"{self.prefix}:{schedule.schedule_id}",
147+
self.serializer.dumpb(model_dump(schedule)),
148+
)
149+
150+
async def get_schedules(self) -> List[ScheduledTask]:
151+
"""
152+
Get all schedules from redis.
153+
154+
This method is used by scheduler to get all schedules.
155+
156+
:return: list of schedules.
157+
"""
158+
schedules = []
159+
buffer = []
160+
async for key in self.redis.scan_iter(f"{self.prefix}:*"): # type: ignore[attr-defined]
161+
buffer.append(key)
162+
if len(buffer) >= self.buffer_size:
163+
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
164+
buffer = []
165+
if buffer:
166+
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
167+
return [
168+
model_validate(ScheduledTask, self.serializer.loadb(schedule))
169+
for schedule in schedules
170+
if schedule
171+
]
172+
173+
async def post_send(self, task: ScheduledTask) -> None:
174+
"""Delete a task after it's completed."""
175+
if task.time is not None:
176+
await self.delete_schedule(task.schedule_id)

Diff for: tests/test_schedule_source.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from taskiq import ScheduledTask
55

6-
from taskiq_redis import RedisScheduleSource
6+
from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource
77

88

99
@pytest.mark.anyio
@@ -59,3 +59,58 @@ async def test_post_run_cron(redis_url: str) -> None:
5959
schedules = await source.get_schedules()
6060
assert schedules == [schedule]
6161
await source.shutdown()
62+
63+
64+
@pytest.mark.anyio
65+
async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
66+
prefix = uuid.uuid4().hex
67+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
68+
schedule = ScheduledTask(
69+
task_name="test_task",
70+
labels={},
71+
args=[],
72+
kwargs={},
73+
cron="* * * * *",
74+
)
75+
await source.add_schedule(schedule)
76+
schedules = await source.get_schedules()
77+
assert schedules == [schedule]
78+
await source.shutdown()
79+
80+
81+
@pytest.mark.anyio
82+
async def test_cluster_delete_schedule(redis_cluster_url: str) -> None:
83+
prefix = uuid.uuid4().hex
84+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
85+
schedule = ScheduledTask(
86+
task_name="test_task",
87+
labels={},
88+
args=[],
89+
kwargs={},
90+
cron="* * * * *",
91+
)
92+
await source.add_schedule(schedule)
93+
schedules = await source.get_schedules()
94+
assert schedules == [schedule]
95+
await source.delete_schedule(schedule.schedule_id)
96+
schedules = await source.get_schedules()
97+
# Schedules are empty.
98+
assert not schedules
99+
await source.shutdown()
100+
101+
102+
@pytest.mark.anyio
103+
async def test_cluster_post_run_cron(redis_cluster_url: str) -> None:
104+
prefix = uuid.uuid4().hex
105+
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
106+
schedule = ScheduledTask(
107+
task_name="test_task",
108+
labels={},
109+
args=[],
110+
kwargs={},
111+
cron="* * * * *",
112+
)
113+
await source.add_schedule(schedule)
114+
schedules = await source.get_schedules()
115+
assert schedules == [schedule]
116+
await source.shutdown()

0 commit comments

Comments
 (0)