
Commit 8e35a10

Added redis schedule source.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent e717019

9 files changed: +230 −36 lines

poetry.lock

+37-34
Some generated files are not rendered by default.

pyproject.toml

+2-2
@@ -26,7 +26,7 @@ keywords = [

 [tool.poetry.dependencies]
 python = "^3.8.1"
-taskiq = "^0"
+taskiq = ">=0.10.1,<1"
 redis = "^5"

 [tool.poetry.group.dev.dependencies]
@@ -124,7 +124,7 @@ convention = "pep257"
 ignore-decorators = ["typing.overload"]

 [tool.ruff.pylint]
-allow-magic-value-types = ["int", "str", "float", "tuple"]
+allow-magic-value-types = ["int", "str", "float"]

 [tool.ruff.flake8-bugbear]
 extend-immutable-calls = ["taskiq_dependencies.Depends", "taskiq.TaskiqDepends"]

taskiq_redis/__init__.py

+2
@@ -1,9 +1,11 @@
 """Package for redis integration."""
 from taskiq_redis.redis_backend import RedisAsyncResultBackend
 from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
+from taskiq_redis.schedule_source import RedisScheduleSource

 __all__ = [
     "RedisAsyncResultBackend",
     "ListQueueBroker",
     "PubSubBroker",
+    "RedisScheduleSource",
 ]
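With this re-export, the new source is importable straight from the package root:

from taskiq_redis import RedisScheduleSource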

taskiq_redis/schedule_source.py

+97 (new file)

import dataclasses
from typing import Any, List, Optional

from redis.asyncio import ConnectionPool, Redis
from taskiq import ScheduleSource
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.scheduler.scheduled_task import ScheduledTask

from taskiq_redis.serializer import PickleSerializer


class RedisScheduleSource(ScheduleSource):
    """
    Source of schedules for redis.

    This class allows you to store schedules in redis.
    It also supports dynamic schedules.

    :param url: url to redis.
    :param prefix: prefix for redis schedule keys.
    :param buffer_size: buffer size for redis scan.
        This is how many keys will be fetched at once.
    :param max_connection_pool_size: maximum number of connections in pool.
    :param serializer: serializer for data.
    :param connection_kwargs: additional arguments for redis ConnectionPool.
    """

    def __init__(
        self,
        url: str,
        prefix: str = "schedule",
        buffer_size: int = 50,
        max_connection_pool_size: Optional[int] = None,
        serializer: Optional[TaskiqSerializer] = None,
        **connection_kwargs: Any,
    ) -> None:
        self.prefix = prefix
        self.connection_pool: ConnectionPool = ConnectionPool.from_url(
            url=url,
            max_connections=max_connection_pool_size,
            **connection_kwargs,
        )
        self.buffer_size = buffer_size
        if serializer is None:
            serializer = PickleSerializer()
        self.serializer = serializer

    async def delete_schedule(self, schedule_id: str) -> None:
        """Remove schedule by id."""
        async with Redis(connection_pool=self.connection_pool) as redis:
            await redis.delete(f"{self.prefix}:{schedule_id}")

    async def add_schedule(self, schedule: ScheduledTask) -> None:
        """
        Add schedule to redis.

        :param schedule: schedule to add.
        """
        async with Redis(connection_pool=self.connection_pool) as redis:
            await redis.set(
                f"{self.prefix}:{schedule.schedule_id}",
                self.serializer.dumpb(dataclasses.asdict(schedule)),
            )

    async def get_schedules(self) -> List[ScheduledTask]:
        """
        Get all schedules from redis.

        This method is used by the scheduler to get all schedules.

        :return: list of schedules.
        """
        schedules = []
        async with Redis(connection_pool=self.connection_pool) as redis:
            buffer = []
            # Scan keys in batches of buffer_size instead of loading
            # the whole keyspace at once.
            async for key in redis.scan_iter(f"{self.prefix}:*"):
                buffer.append(key)
                if len(buffer) >= self.buffer_size:
                    schedules.extend(await redis.mget(buffer))
                    buffer = []
            if buffer:
                schedules.extend(await redis.mget(buffer))
        return [
            ScheduledTask(**self.serializer.loadb(schedule))
            for schedule in schedules
            if schedule
        ]

    async def post_send(self, task: ScheduledTask) -> None:
        """Delete a task after it's completed."""
        if task.time is not None:
            await self.delete_schedule(task.schedule_id)

    async def shutdown(self) -> None:
        """Shut down the schedule source."""
        await self.connection_pool.disconnect()
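For context, here is a minimal usage sketch, not part of this commit, of how the new source could be wired into a taskiq scheduler and how a one-off schedule could be registered at runtime through add_schedule(). The redis URL, the example task, and the ScheduledTask field values are illustrative placeholders; the field names follow taskiq's ScheduledTask dataclass and may differ between taskiq versions.

# Hedged usage sketch (not part of this commit).
import datetime as dt

from taskiq import TaskiqScheduler
from taskiq.scheduler.scheduled_task import ScheduledTask

from taskiq_redis import ListQueueBroker, RedisScheduleSource

broker = ListQueueBroker("redis://localhost:6379")
source = RedisScheduleSource("redis://localhost:6379", prefix="schedule")
scheduler = TaskiqScheduler(broker=broker, sources=[source])


@broker.task
async def heartbeat() -> None:
    print("still alive")


async def schedule_one_heartbeat() -> None:
    # Stored under "schedule:heartbeat-once"; the scheduler picks it up via
    # get_schedules(), and post_send() deletes it because `time` is set,
    # so it fires only once.
    await source.add_schedule(
        ScheduledTask(
            task_name=heartbeat.task_name,
            labels={},
            args=[],
            kwargs={},
            schedule_id="heartbeat-once",
            time=dt.datetime.utcnow() + dt.timedelta(minutes=5),
        ),
    )

The scheduler process itself is usually started with the taskiq CLI, e.g. "taskiq scheduler path.to.module:scheduler" (the module path is a placeholder).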

taskiq_redis/serializer.py

+16 (new file)

import pickle
from typing import Any

from taskiq.abc.serializer import TaskiqSerializer


class PickleSerializer(TaskiqSerializer):
    """Serializer that uses pickle."""

    def dumpb(self, value: Any) -> bytes:
        """Dumps value to bytes."""
        return pickle.dumps(value)

    def loadb(self, value: bytes) -> Any:
        """Loads value from bytes."""
        return pickle.loads(value)  # noqa: S301
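PickleSerializer is only the default; any TaskiqSerializer implementing dumpb and loadb can be passed to RedisScheduleSource via the serializer argument. As a hedged illustration (not part of this commit), a JSON-backed serializer could look like the sketch below; pickle stays the default because ScheduledTask may carry values, such as a datetime in its time field, that JSON does not round-trip cleanly.

# Illustrative alternative serializer (not part of this commit).
import json
from typing import Any

from taskiq.abc.serializer import TaskiqSerializer

from taskiq_redis import RedisScheduleSource


class JSONSerializer(TaskiqSerializer):
    """Serializer that stores schedules as JSON instead of pickle."""

    def dumpb(self, value: Any) -> bytes:
        # default=str stringifies non-JSON values such as datetimes.
        return json.dumps(value, default=str).encode()

    def loadb(self, value: bytes) -> Any:
        return json.loads(value)


source = RedisScheduleSource("redis://localhost:6379", serializer=JSONSerializer())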

tests/test_backend.py

+6
@@ -86,6 +86,7 @@ async def test_success_backend_default_result(
     result = await backend.get_result(task_id=task_id)

     assert result == default_taskiq_result
+    await backend.shutdown()


 @pytest.mark.anyio
@@ -117,6 +118,7 @@ async def test_success_backend_custom_result(
     assert result.is_err == custom_taskiq_result.is_err
     assert result.execution_time == custom_taskiq_result.execution_time
     assert result.log == custom_taskiq_result.log
+    await backend.shutdown()


 @pytest.mark.anyio
@@ -183,6 +185,7 @@ async def test_success_backend_expire_ex_param(
     result = await backend.get_result(task_id=task_id)

     assert result == default_taskiq_result
+    await backend.shutdown()


 @pytest.mark.anyio
@@ -213,6 +216,7 @@ async def test_unsuccess_backend_expire_ex_param(

     with pytest.raises(ResultIsMissingError):
         await backend.get_result(task_id=task_id)
+    await backend.shutdown()


 @pytest.mark.anyio
@@ -243,6 +247,7 @@ async def test_success_backend_expire_px_param(
     result = await backend.get_result(task_id=task_id)

     assert result == default_taskiq_result
+    await backend.shutdown()


 @pytest.mark.anyio
@@ -273,3 +278,4 @@ async def test_unsuccess_backend_expire_px_param(

     with pytest.raises(ResultIsMissingError):
         await backend.get_result(task_id=task_id)
+    await backend.shutdown()

tests/test_broker.py

+4
@@ -68,6 +68,7 @@ async def test_pub_sub_broker(
     message2 = worker2_task.result()
     assert message1 == valid_broker_message.message
     assert message1 == message2
+    await broker.shutdown()


 @pytest.mark.anyio
@@ -92,3 +93,6 @@ async def test_list_queue_broker(
     assert worker1_task.done() != worker2_task.done()
     message = worker1_task.result() if worker1_task.done() else worker2_task.result()
     assert message == valid_broker_message.message
+    worker1_task.cancel()
+    worker2_task.cancel()
+    await broker.shutdown()

tests/test_result_backend.py

+5
@@ -37,6 +37,7 @@ async def test_set_result_success(redis_url: str) -> None:
     assert fetched_result.return_value == 11
     assert fetched_result.execution_time == 112.2
     assert fetched_result.is_err
+    await result_backend.shutdown()


 @pytest.mark.anyio
@@ -69,6 +70,7 @@ async def test_fetch_without_logs(redis_url: str) -> None:
     assert fetched_result.return_value == 11
     assert fetched_result.execution_time == 112.2
     assert fetched_result.is_err
+    await result_backend.shutdown()


 @pytest.mark.anyio
@@ -98,6 +100,8 @@ async def test_remove_results_after_reading(redis_url: str) -> None:
     with pytest.raises(ResultIsMissingError):
         await result_backend.get_result(task_id=task_id)

+    await result_backend.shutdown()
+

 @pytest.mark.anyio
 async def test_keep_results_after_reading(redis_url: str) -> None:
@@ -125,3 +129,4 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
     res1 = await result_backend.get_result(task_id=task_id)
     res2 = await result_backend.get_result(task_id=task_id)
     assert res1 == res2
+    await result_backend.shutdown()
