Skip to content

Added redis schedule source. #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ keywords = [

[tool.poetry.dependencies]
python = "^3.8.1"
taskiq = "^0"
taskiq = ">=0.10.1,<1"
redis = "^5"

[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -124,7 +124,7 @@ convention = "pep257"
ignore-decorators = ["typing.overload"]

[tool.ruff.pylint]
allow-magic-value-types = ["int", "str", "float", "tuple"]
allow-magic-value-types = ["int", "str", "float"]

[tool.ruff.flake8-bugbear]
extend-immutable-calls = ["taskiq_dependencies.Depends", "taskiq.TaskiqDepends"]
2 changes: 2 additions & 0 deletions taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Package for redis integration."""
from taskiq_redis.redis_backend import RedisAsyncResultBackend
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.schedule_source import RedisScheduleSource

__all__ = [
"RedisAsyncResultBackend",
"ListQueueBroker",
"PubSubBroker",
"RedisScheduleSource",
]
97 changes: 97 additions & 0 deletions taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import dataclasses
from typing import Any, List, Optional

from redis.asyncio import ConnectionPool, Redis
from taskiq import ScheduleSource
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.scheduler.scheduled_task import ScheduledTask

from taskiq_redis.serializer import PickleSerializer


class RedisScheduleSource(ScheduleSource):
    """
    Source of schedules for redis.

    This class allows you to store schedules in redis.
    Also it supports dynamic schedules.

    :param url: url to redis.
    :param prefix: prefix for redis schedule keys.
    :param buffer_size: buffer size for redis scan.
        This is how many keys will be fetched at once.
    :param max_connection_pool_size: maximum number of connections in pool.
    :param serializer: serializer for data.
    :param connection_kwargs: additional arguments for aio-redis ConnectionPool.
    """

    def __init__(
        self,
        url: str,
        prefix: str = "schedule",
        buffer_size: int = 50,
        max_connection_pool_size: Optional[int] = None,
        serializer: Optional[TaskiqSerializer] = None,
        **connection_kwargs: Any,
    ) -> None:
        self.prefix = prefix
        self.connection_pool: ConnectionPool = ConnectionPool.from_url(
            url=url,
            max_connections=max_connection_pool_size,
            **connection_kwargs,
        )
        self.buffer_size = buffer_size
        # Default to pickle-based serialization when none is supplied.
        # NOTE(review): pickle.loads on data read back from redis executes
        # arbitrary code if the redis instance is compromised — consider a
        # JSON-based serializer for untrusted environments.
        if serializer is None:
            serializer = PickleSerializer()
        self.serializer = serializer

    async def delete_schedule(self, schedule_id: str) -> None:
        """
        Remove schedule by id.

        :param schedule_id: id of the schedule to remove.
        """
        async with Redis(connection_pool=self.connection_pool) as redis:
            await redis.delete(f"{self.prefix}:{schedule_id}")

    async def add_schedule(self, schedule: ScheduledTask) -> None:
        """
        Add schedule to redis.

        The schedule is serialized and stored under the key
        ``{prefix}:{schedule_id}``.

        :param schedule: schedule to add.
        """
        async with Redis(connection_pool=self.connection_pool) as redis:
            await redis.set(
                f"{self.prefix}:{schedule.schedule_id}",
                self.serializer.dumpb(dataclasses.asdict(schedule)),
            )

    async def get_schedules(self) -> List[ScheduledTask]:
        """
        Get all schedules from redis.

        This method is used by scheduler to get all schedules.

        Keys are discovered with SCAN and fetched in batches of
        ``buffer_size`` via MGET to avoid one round-trip per key.

        :return: list of schedules.
        """
        schedules: List[Any] = []
        async with Redis(connection_pool=self.connection_pool) as redis:
            buffer: List[Any] = []
            async for key in redis.scan_iter(f"{self.prefix}:*"):
                buffer.append(key)
                if len(buffer) >= self.buffer_size:
                    schedules.extend(await redis.mget(buffer))
                    buffer = []
            # Flush keys that did not fill a whole batch.
            if buffer:
                schedules.extend(await redis.mget(buffer))
        # MGET yields None for keys deleted between SCAN and MGET;
        # skip those before deserializing.
        return [
            ScheduledTask(**self.serializer.loadb(schedule))
            for schedule in schedules
            if schedule
        ]

    async def post_send(self, task: ScheduledTask) -> None:
        """
        Delete a task after it's completed.

        Only schedules with a concrete one-off ``time`` are removed;
        recurring schedules are kept for future runs.

        :param task: task that has just been sent.
        """
        if task.time is not None:
            await self.delete_schedule(task.schedule_id)

    async def shutdown(self) -> None:
        """Shut down the schedule source and release pooled connections."""
        await self.connection_pool.disconnect()
16 changes: 16 additions & 0 deletions taskiq_redis/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pickle
from typing import Any

from taskiq.abc.serializer import TaskiqSerializer


class PickleSerializer(TaskiqSerializer):
    """TaskiqSerializer implementation backed by the stdlib pickle module."""

    def dumpb(self, value: Any) -> bytes:
        """Serialize *value* into a pickle byte string."""
        return pickle.dumps(value)

    def loadb(self, value: bytes) -> Any:
        """Reconstruct an object from a pickle byte string."""
        return pickle.loads(value)  # noqa: S301
6 changes: 6 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def test_success_backend_default_result(
result = await backend.get_result(task_id=task_id)

assert result == default_taskiq_result
await backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -117,6 +118,7 @@ async def test_success_backend_custom_result(
assert result.is_err == custom_taskiq_result.is_err
assert result.execution_time == custom_taskiq_result.execution_time
assert result.log == custom_taskiq_result.log
await backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -183,6 +185,7 @@ async def test_success_backend_expire_ex_param(
result = await backend.get_result(task_id=task_id)

assert result == default_taskiq_result
await backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -213,6 +216,7 @@ async def test_unsuccess_backend_expire_ex_param(

with pytest.raises(ResultIsMissingError):
await backend.get_result(task_id=task_id)
await backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -243,6 +247,7 @@ async def test_success_backend_expire_px_param(
result = await backend.get_result(task_id=task_id)

assert result == default_taskiq_result
await backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -273,3 +278,4 @@ async def test_unsuccess_backend_expire_px_param(

with pytest.raises(ResultIsMissingError):
await backend.get_result(task_id=task_id)
await backend.shutdown()
4 changes: 4 additions & 0 deletions tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ async def test_pub_sub_broker(
message2 = worker2_task.result()
assert message1 == valid_broker_message.message
assert message1 == message2
await broker.shutdown()


@pytest.mark.anyio
Expand All @@ -92,3 +93,6 @@ async def test_list_queue_broker(
assert worker1_task.done() != worker2_task.done()
message = worker1_task.result() if worker1_task.done() else worker2_task.result()
assert message == valid_broker_message.message
worker1_task.cancel()
worker2_task.cancel()
await broker.shutdown()
5 changes: 5 additions & 0 deletions tests/test_result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def test_set_result_success(redis_url: str) -> None:
assert fetched_result.return_value == 11
assert fetched_result.execution_time == 112.2
assert fetched_result.is_err
await result_backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -69,6 +70,7 @@ async def test_fetch_without_logs(redis_url: str) -> None:
assert fetched_result.return_value == 11
assert fetched_result.execution_time == 112.2
assert fetched_result.is_err
await result_backend.shutdown()


@pytest.mark.anyio
Expand Down Expand Up @@ -98,6 +100,8 @@ async def test_remove_results_after_reading(redis_url: str) -> None:
with pytest.raises(ResultIsMissingError):
await result_backend.get_result(task_id=task_id)

await result_backend.shutdown()


@pytest.mark.anyio
async def test_keep_results_after_reading(redis_url: str) -> None:
Expand Down Expand Up @@ -125,3 +129,4 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
res1 = await result_backend.get_result(task_id=task_id)
res2 = await result_backend.get_result(task_id=task_id)
assert res1 == res2
await result_backend.shutdown()
Loading