diff --git a/docker-compose.yml b/docker-compose.yml
index f7810f2..cce7408 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -58,3 +58,24 @@ services:
       REDIS_CLUSTER_CREATOR: "yes"
     ports:
       - 7001:6379
+
+  redis-master:
+    image: bitnami/redis:6.2.5
+    environment:
+      ALLOW_EMPTY_PASSWORD: "yes"
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 5s
+      timeout: 5s
+      retries: 3
+      start_period: 10s
+
+  redis-sentinel:
+    image: bitnami/redis-sentinel:latest
+    depends_on:
+      - redis-master
+    environment:
+      ALLOW_EMPTY_PASSWORD: "yes"
+      REDIS_MASTER_HOST: "redis-master"
+    ports:
+      - 7002:26379
diff --git a/taskiq_redis/__init__.py b/taskiq_redis/__init__.py
index b8262a1..0071c62 100644
--- a/taskiq_redis/__init__.py
+++ b/taskiq_redis/__init__.py
@@ -2,20 +2,30 @@
 from taskiq_redis.redis_backend import (
     RedisAsyncClusterResultBackend,
     RedisAsyncResultBackend,
+    RedisAsyncSentinelResultBackend,
 )
 from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
 from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
+from taskiq_redis.redis_sentinel_broker import (
+    ListQueueSentinelBroker,
+    PubSubSentinelBroker,
+)
 from taskiq_redis.schedule_source import (
     RedisClusterScheduleSource,
     RedisScheduleSource,
+    RedisSentinelScheduleSource,
 )
 
 __all__ = [
     "RedisAsyncClusterResultBackend",
     "RedisAsyncResultBackend",
+    "RedisAsyncSentinelResultBackend",
     "ListQueueBroker",
     "PubSubBroker",
     "ListQueueClusterBroker",
+    "ListQueueSentinelBroker",
+    "PubSubSentinelBroker",
     "RedisScheduleSource",
     "RedisClusterScheduleSource",
+    "RedisSentinelScheduleSource",
 ]
diff --git a/taskiq_redis/redis_backend.py b/taskiq_redis/redis_backend.py
index 026653c..dbd9e20 100644
--- a/taskiq_redis/redis_backend.py
+++ b/taskiq_redis/redis_backend.py
@@ -1,16 +1,40 @@
 import pickle
-from typing import Any, Dict, Optional, TypeVar, Union
+import sys
+from contextlib import asynccontextmanager
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
-from redis.asyncio import BlockingConnectionPool, Redis
+from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
 from redis.asyncio.cluster import RedisCluster
 from taskiq import AsyncResultBackend
 from taskiq.abc.result_backend import TaskiqResult
+from taskiq.abc.serializer import TaskiqSerializer
 
 from taskiq_redis.exceptions import (
     DuplicateExpireTimeSelectedError,
     ExpireTimeMustBeMoreThanZeroError,
     ResultIsMissingError,
 )
+from taskiq_redis.serializer import PickleSerializer
+
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+    _Redis: TypeAlias = Redis[bytes]
+else:
+    _Redis: TypeAlias = Redis
 
 _ReturnType = TypeVar("_ReturnType")
 
@@ -267,3 +291,144 @@ async def get_result(
             taskiq_result.log = None
 
         return taskiq_result
+
+
+class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
+    """Async result backend based on Redis Sentinel."""
+
+    def __init__(
+        self,
+        sentinels: List[Tuple[str, int]],
+        master_name: str,
+        keep_results: bool = True,
+        result_ex_time: Optional[int] = None,
+        result_px_time: Optional[int] = None,
+        min_other_sentinels: int = 0,
+        sentinel_kwargs: Optional[Any] = None,
+        serializer: Optional[TaskiqSerializer] = None,
+        **connection_kwargs: Any,
+    ) -> None:
+        """
+        Constructs a new result backend.
+
+        :param sentinels: list of sentinel host and port pairs.
+        :param master_name: sentinel master name.
+        :param keep_results: flag to not remove results from Redis after reading.
+        :param result_ex_time: expire time in seconds for result.
+        :param result_px_time: expire time in milliseconds for result.
+        :param min_other_sentinels: minimum number of other sentinels to report.
+        :param sentinel_kwargs: additional arguments for the Sentinel client.
+        :param serializer: serializer for results, defaults to PickleSerializer.
+        :param connection_kwargs: additional arguments for Sentinel connections.
+
+        :raises DuplicateExpireTimeSelectedError: if both result_ex_time
+            and result_px_time are set.
+        :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
+            or result_px_time is less than or equal to zero.
+        """
+        self.sentinel = Sentinel(
+            sentinels=sentinels,
+            min_other_sentinels=min_other_sentinels,
+            sentinel_kwargs=sentinel_kwargs,
+            **connection_kwargs,
+        )
+        self.master_name = master_name
+        if serializer is None:
+            serializer = PickleSerializer()
+        self.serializer = serializer
+        self.keep_results = keep_results
+        self.result_ex_time = result_ex_time
+        self.result_px_time = result_px_time
+
+        unavailable_conditions = any(
+            (
+                self.result_ex_time is not None and self.result_ex_time <= 0,
+                self.result_px_time is not None and self.result_px_time <= 0,
+            ),
+        )
+        if unavailable_conditions:
+            raise ExpireTimeMustBeMoreThanZeroError(
+                "You must select one expire time param and it must be more than zero.",
+            )
+
+        if self.result_ex_time and self.result_px_time:
+            raise DuplicateExpireTimeSelectedError(
+                "Choose either result_ex_time or result_px_time.",
+            )
+
+    @asynccontextmanager
+    async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
+        async with self.sentinel.master_for(self.master_name) as redis_conn:
+            yield redis_conn
+
+    async def set_result(
+        self,
+        task_id: str,
+        result: TaskiqResult[_ReturnType],
+    ) -> None:
+        """
+        Sets task result in redis.
+
+        Dumps TaskiqResult instance into bytes and writes
+        it to redis.
+
+        :param task_id: ID of the task.
+        :param result: TaskiqResult instance.
+        """
+        redis_set_params: Dict[str, Union[str, bytes, int]] = {
+            "name": task_id,
+            "value": self.serializer.dumpb(result),
+        }
+        if self.result_ex_time:
+            redis_set_params["ex"] = self.result_ex_time
+        elif self.result_px_time:
+            redis_set_params["px"] = self.result_px_time
+
+        async with self._acquire_master_conn() as redis:
+            await redis.set(**redis_set_params)  # type: ignore
+
+    async def is_result_ready(self, task_id: str) -> bool:
+        """
+        Returns whether the result is ready.
+
+        :param task_id: ID of the task.
+
+        :returns: True if the result is ready, False otherwise.
+        """
+        async with self._acquire_master_conn() as redis:
+            return bool(await redis.exists(task_id))
+
+    async def get_result(
+        self,
+        task_id: str,
+        with_logs: bool = False,
+    ) -> TaskiqResult[_ReturnType]:
+        """
+        Gets result from the task.
+
+        :param task_id: task's id.
+        :param with_logs: if True it will download task's logs.
+        :raises ResultIsMissingError: if there is no result when trying to get it.
+        :return: task's return value.
+ """ + async with self._acquire_master_conn() as redis: + if self.keep_results: + result_value = await redis.get( + name=task_id, + ) + else: + result_value = await redis.getdel( + name=task_id, + ) + + if result_value is None: + raise ResultIsMissingError + + taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301 + result_value, + ) + + if not with_logs: + taskiq_result.log = None + + return taskiq_result diff --git a/taskiq_redis/redis_sentinel_broker.py b/taskiq_redis/redis_sentinel_broker.py new file mode 100644 index 0000000..8d0bef8 --- /dev/null +++ b/taskiq_redis/redis_sentinel_broker.py @@ -0,0 +1,132 @@ +import sys +from contextlib import asynccontextmanager +from logging import getLogger +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterator, + Callable, + List, + Optional, + Tuple, + TypeVar, +) + +from redis.asyncio import Redis, Sentinel +from taskiq import AsyncResultBackend, BrokerMessage +from taskiq.abc.broker import AsyncBroker + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +if TYPE_CHECKING: + _Redis: TypeAlias = Redis[bytes] +else: + _Redis: TypeAlias = Redis + +_T = TypeVar("_T") + +logger = getLogger("taskiq.redis_sentinel_broker") + + +class BaseSentinelBroker(AsyncBroker): + """Base broker that works with Sentinel.""" + + def __init__( + self, + sentinels: List[Tuple[str, int]], + master_name: str, + result_backend: Optional[AsyncResultBackend[_T]] = None, + task_id_generator: Optional[Callable[[], str]] = None, + queue_name: str = "taskiq", + min_other_sentinels: int = 0, + sentinel_kwargs: Optional[Any] = None, + **connection_kwargs: Any, + ) -> None: + super().__init__( + result_backend=result_backend, + task_id_generator=task_id_generator, + ) + + self.sentinel = Sentinel( + sentinels=sentinels, + min_other_sentinels=min_other_sentinels, + sentinel_kwargs=sentinel_kwargs, + **connection_kwargs, + ) + self.master_name = master_name + self.queue_name = queue_name + + @asynccontextmanager + async def _acquire_master_conn(self) -> AsyncIterator[_Redis]: + async with self.sentinel.master_for(self.master_name) as redis_conn: + yield redis_conn + + +class PubSubSentinelBroker(BaseSentinelBroker): + """Broker that works with Sentinel and broadcasts tasks to all workers.""" + + async def kick(self, message: BrokerMessage) -> None: + """ + Publish message over PUBSUB channel. + + :param message: message to send. + """ + queue_name = message.labels.get("queue_name") or self.queue_name + async with self._acquire_master_conn() as redis_conn: + await redis_conn.publish(queue_name, message.message) + + async def listen(self) -> AsyncGenerator[bytes, None]: + """ + Listen redis queue for new messages. + + This function listens to the pubsub channel + and yields all messages with proper types. + + :yields: broker messages. + """ + async with self._acquire_master_conn() as redis_conn: + redis_pubsub_channel = redis_conn.pubsub() + await redis_pubsub_channel.subscribe(self.queue_name) + async for message in redis_pubsub_channel.listen(): + if not message: + continue + if message["type"] != "message": + logger.debug("Received non-message from redis: %s", message) + continue + yield message["data"] + + +class ListQueueSentinelBroker(BaseSentinelBroker): + """Broker that works with Sentinel and distributes tasks between workers.""" + + async def kick(self, message: BrokerMessage) -> None: + """ + Put a message in a list. 
+
+        This method appends a message to the list of all messages.
+
+        :param message: message to append.
+        """
+        queue_name = message.labels.get("queue_name") or self.queue_name
+        async with self._acquire_master_conn() as redis_conn:
+            await redis_conn.lpush(queue_name, message.message)
+
+    async def listen(self) -> AsyncGenerator[bytes, None]:
+        """
+        Listen to the redis queue for new messages.
+
+        This function listens to the queue
+        and yields new messages as raw bytes.
+
+        :yields: broker messages.
+        """
+        redis_brpop_data_position = 1
+        async with self._acquire_master_conn() as redis_conn:
+            while True:
+                yield (await redis_conn.brpop(self.queue_name))[
+                    redis_brpop_data_position
+                ]
diff --git a/taskiq_redis/schedule_source.py b/taskiq_redis/schedule_source.py
index fd3d922..6043e1f 100644
--- a/taskiq_redis/schedule_source.py
+++ b/taskiq_redis/schedule_source.py
@@ -1,6 +1,14 @@
-from typing import Any, List, Optional
+import sys
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional, Tuple
 
-from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis, RedisCluster
+from redis.asyncio import (
+    BlockingConnectionPool,
+    ConnectionPool,
+    Redis,
+    RedisCluster,
+    Sentinel,
+)
 from taskiq import ScheduleSource
 from taskiq.abc.serializer import TaskiqSerializer
 from taskiq.compat import model_dump, model_validate
@@ -8,6 +16,16 @@
 from taskiq_redis.serializer import PickleSerializer
 
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+    _Redis: TypeAlias = Redis[bytes]
+else:
+    _Redis: TypeAlias = Redis
+
 
 class RedisScheduleSource(ScheduleSource):
     """
@@ -167,3 +185,97 @@ async def post_send(self, task: ScheduledTask) -> None:
         """Delete a task after it's completed."""
         if task.time is not None:
             await self.delete_schedule(task.schedule_id)
+
+
+class RedisSentinelScheduleSource(ScheduleSource):
+    """
+    Source of schedules for Redis Sentinel.
+
+    This class allows you to store schedules in redis.
+    It also supports dynamic schedules.
+
+    :param sentinels: list of sentinel host and port pairs.
+    :param master_name: sentinel master name.
+    :param prefix: prefix for redis schedule keys.
+    :param buffer_size: buffer size for redis scan.
+        This is how many keys will be fetched at once.
+    :param serializer: serializer for data.
+    :param min_other_sentinels: minimum number of other sentinels to report.
+    :param sentinel_kwargs: additional arguments for the Sentinel client.
+    :param connection_kwargs: additional arguments for Sentinel connections.
+ """ + + def __init__( + self, + sentinels: List[Tuple[str, int]], + master_name: str, + prefix: str = "schedule", + buffer_size: int = 50, + serializer: Optional[TaskiqSerializer] = None, + min_other_sentinels: int = 0, + sentinel_kwargs: Optional[Any] = None, + **connection_kwargs: Any, + ) -> None: + self.prefix = prefix + self.sentinel = Sentinel( + sentinels=sentinels, + min_other_sentinels=min_other_sentinels, + sentinel_kwargs=sentinel_kwargs, + **connection_kwargs, + ) + self.master_name = master_name + self.buffer_size = buffer_size + if serializer is None: + serializer = PickleSerializer() + self.serializer = serializer + + @asynccontextmanager + async def _acquire_master_conn(self) -> AsyncIterator[_Redis]: + async with self.sentinel.master_for(self.master_name) as redis_conn: + yield redis_conn + + async def delete_schedule(self, schedule_id: str) -> None: + """Remove schedule by id.""" + async with self._acquire_master_conn() as redis: + await redis.delete(f"{self.prefix}:{schedule_id}") + + async def add_schedule(self, schedule: ScheduledTask) -> None: + """ + Add schedule to redis. + + :param schedule: schedule to add. + :param schedule_id: schedule id. + """ + async with self._acquire_master_conn() as redis: + await redis.set( + f"{self.prefix}:{schedule.schedule_id}", + self.serializer.dumpb(model_dump(schedule)), + ) + + async def get_schedules(self) -> List[ScheduledTask]: + """ + Get all schedules from redis. + + This method is used by scheduler to get all schedules. + + :return: list of schedules. + """ + schedules = [] + async with self._acquire_master_conn() as redis: + buffer = [] + async for key in redis.scan_iter(f"{self.prefix}:*"): + buffer.append(key) + if len(buffer) >= self.buffer_size: + schedules.extend(await redis.mget(buffer)) + buffer = [] + if buffer: + schedules.extend(await redis.mget(buffer)) + return [ + model_validate(ScheduledTask, self.serializer.loadb(schedule)) + for schedule in schedules + if schedule + ] + + async def post_send(self, task: ScheduledTask) -> None: + """Delete a task after it's completed.""" + if task.time is not None: + await self.delete_schedule(task.schedule_id) diff --git a/tests/conftest.py b/tests/conftest.py index dcccb79..1d000b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from typing import List, Tuple import pytest @@ -40,3 +41,33 @@ def redis_cluster_url() -> str: :return: URL string. """ return os.environ.get("TEST_REDIS_CLUSTER_URL", "redis://localhost:7001") + + +@pytest.fixture +def redis_sentinels() -> List[Tuple[str, int]]: + """ + List of redis sentinel hosts. + + It tries to get it from environ, + and return default one if the variable is + not set. + + :return: list of host and port pairs. + """ + sentinels = os.environ.get("TEST_REDIS_SENTINELS", "localhost:7002") + host, _, port = sentinels.partition(":") + return [(host, int(port))] + + +@pytest.fixture +def redis_sentinel_master_name() -> str: + """ + Redis sentinel master name. + + It tries to get it from environ, + and return default one if the variable is + not set. + + :return: redis sentinel master name string. 
+ """ + return os.environ.get("TEST_REDIS_SENTINEL_MASTER_NAME", "mymaster") diff --git a/tests/test_broker.py b/tests/test_broker.py index 08f5dff..6fb0be8 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -1,11 +1,17 @@ import asyncio import uuid -from typing import Union +from typing import List, Tuple, Union import pytest from taskiq import AckableMessage, AsyncBroker, BrokerMessage -from taskiq_redis import ListQueueBroker, ListQueueClusterBroker, PubSubBroker +from taskiq_redis import ( + ListQueueBroker, + ListQueueClusterBroker, + ListQueueSentinelBroker, + PubSubBroker, + PubSubSentinelBroker, +) def test_no_url_should_raise_typeerror() -> None: @@ -169,3 +175,63 @@ async def test_list_queue_cluster_broker( assert worker_task.result() == valid_broker_message.message worker_task.cancel() await broker.shutdown() + + +@pytest.mark.anyio +async def test_pub_sub_sentinel_broker( + valid_broker_message: BrokerMessage, + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Test that messages are published and read correctly by PubSubSentinelBroker. + + We create two workers that listen and send a message to them. + Expect both workers to receive the same message we sent. + """ + broker = PubSubSentinelBroker( + sentinels=redis_sentinels, + master_name=redis_sentinel_master_name, + queue_name=uuid.uuid4().hex, + ) + worker1_task = asyncio.create_task(get_message(broker)) + worker2_task = asyncio.create_task(get_message(broker)) + await asyncio.sleep(0.3) + + await broker.kick(valid_broker_message) + await asyncio.sleep(0.3) + + message1 = worker1_task.result() + message2 = worker2_task.result() + assert message1 == valid_broker_message.message + assert message1 == message2 + await broker.shutdown() + + +@pytest.mark.anyio +async def test_list_queue_sentinel_broker( + valid_broker_message: BrokerMessage, + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Test that messages are published and read correctly by ListQueueSentinelBroker. + + We create two workers that listen and send a message to them. + Expect only one worker to receive the same message we sent. 
+ """ + broker = ListQueueSentinelBroker( + sentinels=redis_sentinels, + master_name=redis_sentinel_master_name, + queue_name=uuid.uuid4().hex, + ) + worker_task = asyncio.create_task(get_message(broker)) + await asyncio.sleep(0.3) + + await broker.kick(valid_broker_message) + await asyncio.sleep(0.3) + + assert worker_task.done() + assert worker_task.result() == valid_broker_message.message + worker_task.cancel() + await broker.shutdown() diff --git a/tests/test_result_backend.py b/tests/test_result_backend.py index d85b28b..68ed965 100644 --- a/tests/test_result_backend.py +++ b/tests/test_result_backend.py @@ -1,10 +1,15 @@ import asyncio import uuid +from typing import List, Tuple import pytest from taskiq import TaskiqResult -from taskiq_redis import RedisAsyncClusterResultBackend, RedisAsyncResultBackend +from taskiq_redis import ( + RedisAsyncClusterResultBackend, + RedisAsyncResultBackend, + RedisAsyncSentinelResultBackend, +) from taskiq_redis.exceptions import ResultIsMissingError @@ -288,3 +293,148 @@ async def test_keep_results_after_reading_cluster(redis_cluster_url: str) -> Non res2 = await result_backend.get_result(task_id=task_id) assert res1 == res2 await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_set_result_success_sentinel( + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Tests that results can be set without errors in cluster mode. + + :param redis_sentinels: list of host and port pairs. + :param redis_sentinel_master_name: redis sentinel master name string. + """ + result_backend = RedisAsyncSentinelResultBackend( # type: ignore + sentinels=redis_sentinels, + master_name=redis_sentinel_master_name, + ) + task_id = uuid.uuid4().hex + result: "TaskiqResult[int]" = TaskiqResult( + is_err=True, + log="My Log", + return_value=11, + execution_time=112.2, + ) + await result_backend.set_result( + task_id=task_id, + result=result, + ) + + fetched_result = await result_backend.get_result( + task_id=task_id, + with_logs=True, + ) + assert fetched_result.log == "My Log" + assert fetched_result.return_value == 11 + assert fetched_result.execution_time == 112.2 + assert fetched_result.is_err + await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_fetch_without_logs_sentinel( + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Check if fetching value without logs works fine. + + :param redis_sentinels: list of host and port pairs. + :param redis_sentinel_master_name: redis sentinel master name string. + """ + result_backend = RedisAsyncSentinelResultBackend( # type: ignore + sentinels=redis_sentinels, + master_name=redis_sentinel_master_name, + ) + task_id = uuid.uuid4().hex + result: "TaskiqResult[int]" = TaskiqResult( + is_err=True, + log="My Log", + return_value=11, + execution_time=112.2, + ) + await result_backend.set_result( + task_id=task_id, + result=result, + ) + + fetched_result = await result_backend.get_result( + task_id=task_id, + with_logs=False, + ) + assert fetched_result.log is None + assert fetched_result.return_value == 11 + assert fetched_result.execution_time == 112.2 + assert fetched_result.is_err + await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_remove_results_after_reading_sentinel( + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Check if removing results after reading works fine. + + :param redis_sentinels: list of host and port pairs. 
+    :param redis_sentinel_master_name: redis sentinel master name string.
+    """
+    result_backend = RedisAsyncSentinelResultBackend(  # type: ignore
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        keep_results=False,
+    )
+    task_id = uuid.uuid4().hex
+    result: "TaskiqResult[int]" = TaskiqResult(
+        is_err=True,
+        log="My Log",
+        return_value=11,
+        execution_time=112.2,
+    )
+    await result_backend.set_result(
+        task_id=task_id,
+        result=result,
+    )
+
+    await result_backend.get_result(task_id=task_id)
+    with pytest.raises(ResultIsMissingError):
+        await result_backend.get_result(task_id=task_id)
+
+    await result_backend.shutdown()
+
+
+@pytest.mark.anyio
+async def test_keep_results_after_reading_sentinel(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    """
+    Check if keeping results after reading works fine.
+
+    :param redis_sentinels: list of host and port pairs.
+    :param redis_sentinel_master_name: redis sentinel master name string.
+    """
+    result_backend = RedisAsyncSentinelResultBackend(  # type: ignore
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        keep_results=True,
+    )
+    task_id = uuid.uuid4().hex
+    result: "TaskiqResult[int]" = TaskiqResult(
+        is_err=True,
+        log="My Log",
+        return_value=11,
+        execution_time=112.2,
+    )
+    await result_backend.set_result(
+        task_id=task_id,
+        result=result,
+    )
+
+    res1 = await result_backend.get_result(task_id=task_id)
+    res2 = await result_backend.get_result(task_id=task_id)
+    assert res1 == res2
+    await result_backend.shutdown()
diff --git a/tests/test_schedule_source.py b/tests/test_schedule_source.py
index fafb17b..18108b2 100644
--- a/tests/test_schedule_source.py
+++ b/tests/test_schedule_source.py
@@ -1,11 +1,16 @@
 import asyncio
 import datetime as dt
 import uuid
+from typing import List, Tuple
 
 import pytest
 from taskiq import ScheduledTask
 
-from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource
+from taskiq_redis import (
+    RedisClusterScheduleSource,
+    RedisScheduleSource,
+    RedisSentinelScheduleSource,
+)
 
 
 @pytest.mark.anyio
@@ -234,3 +239,140 @@ async def test_cluster_get_schedules(redis_cluster_url: str) -> None:
     assert schedule1 in schedules
     assert schedule2 in schedules
     await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_sentinel_set_schedule(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisSentinelScheduleSource(
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        prefix=prefix,
+    )
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
+    schedules = await source.get_schedules()
+    assert schedules == [schedule]
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_sentinel_delete_schedule(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisSentinelScheduleSource(
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        prefix=prefix,
+    )
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
+    schedules = await source.get_schedules()
+    assert schedules == [schedule]
+    await source.delete_schedule(schedule.schedule_id)
+    schedules = await source.get_schedules()
+    # Schedules are empty.
+    assert not schedules
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_sentinel_post_run_cron(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisSentinelScheduleSource(
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        prefix=prefix,
+    )
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_sentinel_post_run_time(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisSentinelScheduleSource(
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        prefix=prefix,
+    )
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        time=dt.datetime(2000, 1, 1),
+    )
+    await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == []
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_sentinel_buffer(
+    redis_sentinels: List[Tuple[str, int]],
+    redis_sentinel_master_name: str,
+) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisSentinelScheduleSource(
+        sentinels=redis_sentinels,
+        master_name=redis_sentinel_master_name,
+        prefix=prefix,
+        buffer_size=1,
+    )
+    schedule1 = ScheduledTask(
+        task_name="test_task1",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    schedule2 = ScheduledTask(
+        task_name="test_task2",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule1)
+    await source.add_schedule(schedule2)
+    schedules = await source.get_schedules()
+    assert len(schedules) == 2
+    assert schedule1 in schedules
+    assert schedule2 in schedules
+    await source.shutdown()
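Usage sketch (editor's addition, not part of the diff): the snippet below wires the new Sentinel broker and result backend together. It assumes a sentinel reachable at localhost:7002 with master name "mymaster", matching the docker-compose service and test fixtures above; the queue name and the task itself are illustrative only.

```python
import asyncio

from taskiq_redis import ListQueueSentinelBroker, RedisAsyncSentinelResultBackend

# Sentinel address and master name match the test defaults above;
# adjust them to your deployment.
sentinels = [("localhost", 7002)]

result_backend = RedisAsyncSentinelResultBackend(
    sentinels=sentinels,
    master_name="mymaster",
    result_ex_time=600,  # keep results for 10 minutes
)

broker = ListQueueSentinelBroker(
    sentinels=sentinels,
    master_name="mymaster",
    queue_name="my_queue",
).with_result_backend(result_backend)


@broker.task
async def add_one(value: int) -> int:
    return value + 1


async def main() -> None:
    await broker.startup()
    # Enqueue the task through the sentinel-resolved master
    # and wait for the result from the same backend.
    task = await add_one.kiq(1)
    result = await task.wait_result(timeout=5)
    print(result.return_value)  # 2
    await broker.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```

RedisSentinelScheduleSource plugs into taskiq's scheduler the same way as the existing RedisScheduleSource, taking the same sentinels/master_name pair instead of a URL.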