Skip to content

Commit 090a00a

Browse files
authored
add optional prefix to redis keys (#74)
1 parent e30ed08 commit 090a00a

File tree

1 file changed

+42
-18
lines changed

1 file changed

+42
-18
lines changed

Diff for: taskiq_redis/redis_backend.py

+42-18
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
result_px_time: Optional[int] = None,
5757
max_connection_pool_size: Optional[int] = None,
5858
serializer: Optional[TaskiqSerializer] = None,
59+
prefix_str: Optional[str] = None,
5960
**connection_kwargs: Any,
6061
) -> None:
6162
"""
@@ -82,6 +83,7 @@ def __init__(
8283
self.keep_results = keep_results
8384
self.result_ex_time = result_ex_time
8485
self.result_px_time = result_px_time
86+
self.prefix_str = prefix_str
8587

8688
unavailable_conditions = any(
8789
(
@@ -99,6 +101,11 @@ def __init__(
99101
"Choose either result_ex_time or result_px_time.",
100102
)
101103

104+
def _task_name(self, task_id: str) -> str:
105+
if self.prefix_str is None:
106+
return task_id
107+
return f"{self.prefix_str}:{task_id}"
108+
102109
async def shutdown(self) -> None:
103110
"""Closes redis connection."""
104111
await self.redis_pool.disconnect()
@@ -119,7 +126,7 @@ async def set_result(
119126
:param result: TaskiqResult instance.
120127
"""
121128
redis_set_params: Dict[str, Union[str, int, bytes]] = {
122-
"name": task_id,
129+
"name": self._task_name(task_id),
123130
"value": self.serializer.dumpb(model_dump(result)),
124131
}
125132
if self.result_ex_time:
@@ -139,7 +146,7 @@ async def is_result_ready(self, task_id: str) -> bool:
139146
:returns: True if the result is ready else False.
140147
"""
141148
async with Redis(connection_pool=self.redis_pool) as redis:
142-
return bool(await redis.exists(task_id))
149+
return bool(await redis.exists(self._task_name(task_id)))
143150

144151
async def get_result(
145152
self,
@@ -154,14 +161,15 @@ async def get_result(
154161
:raises ResultIsMissingError: if there is no result when trying to get it.
155162
:return: task's return value.
156163
"""
164+
task_name = self._task_name(task_id)
157165
async with Redis(connection_pool=self.redis_pool) as redis:
158166
if self.keep_results:
159167
result_value = await redis.get(
160-
name=task_id,
168+
name=task_name,
161169
)
162170
else:
163171
result_value = await redis.getdel(
164-
name=task_id,
172+
name=task_name,
165173
)
166174

167175
if result_value is None:
@@ -192,7 +200,7 @@ async def set_progress(
192200
:param result: task's TaskProgress instance.
193201
"""
194202
redis_set_params: Dict[str, Union[str, int, bytes]] = {
195-
"name": task_id + PROGRESS_KEY_SUFFIX,
203+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
196204
"value": self.serializer.dumpb(model_dump(progress)),
197205
}
198206
if self.result_ex_time:
@@ -215,7 +223,7 @@ async def get_progress(
215223
"""
216224
async with Redis(connection_pool=self.redis_pool) as redis:
217225
result_value = await redis.get(
218-
name=task_id + PROGRESS_KEY_SUFFIX,
226+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
219227
)
220228

221229
if result_value is None:
@@ -237,6 +245,7 @@ def __init__(
237245
result_ex_time: Optional[int] = None,
238246
result_px_time: Optional[int] = None,
239247
serializer: Optional[TaskiqSerializer] = None,
248+
prefix_str: Optional[str] = None,
240249
**connection_kwargs: Any,
241250
) -> None:
242251
"""
@@ -261,6 +270,7 @@ def __init__(
261270
self.keep_results = keep_results
262271
self.result_ex_time = result_ex_time
263272
self.result_px_time = result_px_time
273+
self.prefix_str = prefix_str
264274

265275
unavailable_conditions = any(
266276
(
@@ -278,6 +288,11 @@ def __init__(
278288
"Choose either result_ex_time or result_px_time.",
279289
)
280290

291+
def _task_name(self, task_id: str) -> str:
292+
if self.prefix_str is None:
293+
return task_id
294+
return f"{self.prefix_str}:{task_id}"
295+
281296
async def shutdown(self) -> None:
282297
"""Closes redis connection."""
283298
await self.redis.aclose() # type: ignore[attr-defined]
@@ -298,7 +313,7 @@ async def set_result(
298313
:param result: TaskiqResult instance.
299314
"""
300315
redis_set_params: Dict[str, Union[str, bytes, int]] = {
301-
"name": task_id,
316+
"name": self._task_name(task_id),
302317
"value": self.serializer.dumpb(model_dump(result)),
303318
}
304319
if self.result_ex_time:
@@ -316,7 +331,7 @@ async def is_result_ready(self, task_id: str) -> bool:
316331
317332
:returns: True if the result is ready else False.
318333
"""
319-
return bool(await self.redis.exists(task_id)) # type: ignore[attr-defined]
334+
return bool(await self.redis.exists(self._task_name(task_id))) # type: ignore[attr-defined]
320335

321336
async def get_result(
322337
self,
@@ -331,13 +346,14 @@ async def get_result(
331346
:raises ResultIsMissingError: if there is no result when trying to get it.
332347
:return: task's return value.
333348
"""
349+
task_name = self._task_name(task_id)
334350
if self.keep_results:
335351
result_value = await self.redis.get( # type: ignore[attr-defined]
336-
name=task_id,
352+
name=task_name,
337353
)
338354
else:
339355
result_value = await self.redis.getdel( # type: ignore[attr-defined]
340-
name=task_id,
356+
name=task_name,
341357
)
342358

343359
if result_value is None:
@@ -368,7 +384,7 @@ async def set_progress(
368384
:param result: task's TaskProgress instance.
369385
"""
370386
redis_set_params: Dict[str, Union[str, int, bytes]] = {
371-
"name": task_id + PROGRESS_KEY_SUFFIX,
387+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
372388
"value": self.serializer.dumpb(model_dump(progress)),
373389
}
374390
if self.result_ex_time:
@@ -389,7 +405,7 @@ async def get_progress(
389405
:return: task's TaskProgress instance.
390406
"""
391407
result_value = await self.redis.get( # type: ignore[attr-defined]
392-
name=task_id + PROGRESS_KEY_SUFFIX,
408+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
393409
)
394410

395411
if result_value is None:
@@ -414,6 +430,7 @@ def __init__(
414430
min_other_sentinels: int = 0,
415431
sentinel_kwargs: Optional[Any] = None,
416432
serializer: Optional[TaskiqSerializer] = None,
433+
prefix_str: Optional[str] = None,
417434
**connection_kwargs: Any,
418435
) -> None:
419436
"""
@@ -443,6 +460,7 @@ def __init__(
443460
self.keep_results = keep_results
444461
self.result_ex_time = result_ex_time
445462
self.result_px_time = result_px_time
463+
self.prefix_str = prefix_str
446464

447465
unavailable_conditions = any(
448466
(
@@ -460,6 +478,11 @@ def __init__(
460478
"Choose either result_ex_time or result_px_time.",
461479
)
462480

481+
def _task_name(self, task_id: str) -> str:
482+
if self.prefix_str is None:
483+
return task_id
484+
return f"{self.prefix_str}:{task_id}"
485+
463486
@asynccontextmanager
464487
async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
465488
async with self.sentinel.master_for(self.master_name) as redis_conn:
@@ -480,7 +503,7 @@ async def set_result(
480503
:param result: TaskiqResult instance.
481504
"""
482505
redis_set_params: Dict[str, Union[str, bytes, int]] = {
483-
"name": task_id,
506+
"name": self._task_name(task_id),
484507
"value": self.serializer.dumpb(model_dump(result)),
485508
}
486509
if self.result_ex_time:
@@ -500,7 +523,7 @@ async def is_result_ready(self, task_id: str) -> bool:
500523
:returns: True if the result is ready else False.
501524
"""
502525
async with self._acquire_master_conn() as redis:
503-
return bool(await redis.exists(task_id))
526+
return bool(await redis.exists(self._task_name(task_id)))
504527

505528
async def get_result(
506529
self,
@@ -515,14 +538,15 @@ async def get_result(
515538
:raises ResultIsMissingError: if there is no result when trying to get it.
516539
:return: task's return value.
517540
"""
541+
task_name = self._task_name(task_id)
518542
async with self._acquire_master_conn() as redis:
519543
if self.keep_results:
520544
result_value = await redis.get(
521-
name=task_id,
545+
name=task_name,
522546
)
523547
else:
524548
result_value = await redis.getdel(
525-
name=task_id,
549+
name=task_name,
526550
)
527551

528552
if result_value is None:
@@ -553,7 +577,7 @@ async def set_progress(
553577
:param result: task's TaskProgress instance.
554578
"""
555579
redis_set_params: Dict[str, Union[str, int, bytes]] = {
556-
"name": task_id + PROGRESS_KEY_SUFFIX,
580+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
557581
"value": self.serializer.dumpb(model_dump(progress)),
558582
}
559583
if self.result_ex_time:
@@ -576,7 +600,7 @@ async def get_progress(
576600
"""
577601
async with self._acquire_master_conn() as redis:
578602
result_value = await redis.get(
579-
name=task_id + PROGRESS_KEY_SUFFIX,
603+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
580604
)
581605

582606
if result_value is None:

0 commit comments

Comments
 (0)