Skip to content

Commit 5e06dce

Browse files
committed
Merge branch 'release/1.4.1'
2 parents b5de2e0 + cffa9b7 commit 5e06dce

File tree

4 files changed

+62
-8
lines changed

4 files changed

+62
-8
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "taskiq-dependencies"
3-
version = "1.4.0"
3+
version = "1.4.1"
44
description = "FastAPI like dependency injection implementation"
55
authors = ["Pavel Kirilin <win10@list.ru>"]
66
readme = "README.md"

taskiq_dependencies/ctx.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
import inspect
3+
from collections import defaultdict
34
from copy import copy
45
from logging import getLogger
5-
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
6+
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Generator, List, Optional
67

78
from taskiq_dependencies.utils import ParamInfo
89

@@ -49,6 +50,8 @@ def traverse_deps( # noqa: C901, WPS210
4950
# to separate dependencies that use cache,
5051
# from dependencies that aren't.
5152
cache = copy(self.initial_cache)
53+
# Cache for all dependencies with kwargs.
54+
kwargs_cache: "DefaultDict[Any, List[Any]]" = defaultdict(list)
5255
# We iterate over topologicaly sorted list of dependencies.
5356
for index, dep in enumerate(self.graph.ordered_deps):
5457
# If this dependency doesn't use cache,
@@ -62,6 +65,19 @@ def traverse_deps( # noqa: C901, WPS210
6265
# If dependency is already calculated.
6366
if dep.dependency in cache:
6467
continue
68+
# For dependencies with kwargs we check kwarged cache.
69+
elif dep.kwargs and dep.dependency in kwargs_cache:
70+
cache_hit = False
71+
# We have to iterate over all cached dependencies with
72+
# kwargs, because users may pass unhashable objects as kwargs.
73+
# That's why we cannot use them as dict keys.
74+
for cached_kwargs, _ in kwargs_cache[dep.dependency]:
75+
if cached_kwargs == dep.kwargs:
76+
cache_hit = True
77+
break
78+
if cache_hit:
79+
continue
80+
6581
kwargs = {}
6682
# Now we get list of dependencies for current top-level dependency
6783
# and iterate over it.
@@ -78,7 +94,13 @@ def traverse_deps( # noqa: C901, WPS210
7894
if subdep.use_cache:
7995
# If this dependency can be calculated, using cache,
8096
# we try to get it from cache.
81-
kwargs[subdep.param_name] = cache[subdep.dependency]
97+
if subdep.kwargs and subdep.dependency in kwargs_cache:
98+
for cached_kwargs, kw_cache in kwargs_cache[subdep.dependency]:
99+
if cached_kwargs == subdep.kwargs:
100+
kwargs[subdep.param_name] = kw_cache
101+
break
102+
else:
103+
kwargs[subdep.param_name] = cache[subdep.dependency]
82104
else:
83105
# If this dependency doesn't use cache,
84106
# we resolve it's dependencies and
@@ -101,9 +123,13 @@ def traverse_deps( # noqa: C901, WPS210
101123
# because we calculate them when needed.
102124
and dep.dependency != ParamInfo
103125
):
104-
user_kwargs = dep.kwargs
126+
user_kwargs = copy(dep.kwargs)
105127
user_kwargs.update(kwargs)
106-
cache[dep.dependency] = yield dep.dependency(**user_kwargs)
128+
resolved = yield dep.dependency(**user_kwargs)
129+
if dep.kwargs:
130+
kwargs_cache[dep.dependency].append((dep.kwargs, resolved))
131+
else:
132+
cache[dep.dependency] = resolved
107133
return kwargs
108134

109135

taskiq_dependencies/graph.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
try:
1010
from fastapi.params import Depends as FastapiDepends # noqa: WPS433
1111
except ImportError:
12-
FastapiDepends = Dependency # type: ignore
12+
FastapiDepends = None
1313

1414

1515
class DependencyGraph:
@@ -183,7 +183,10 @@ def _build_graph(self) -> None: # noqa: C901, WPS210
183183

184184
# This is for FastAPI integration. So you can
185185
# use Depends from taskiq mixed with fastapi's dependencies.
186-
if isinstance(default_value, FastapiDepends):
186+
if FastapiDepends is not None and isinstance( # noqa: WPS337
187+
default_value,
188+
FastapiDepends,
189+
):
187190
default_value = Dependency(
188191
dependency=default_value.dependency,
189192
use_cache=default_value.use_cache,
@@ -194,7 +197,6 @@ def _build_graph(self) -> None: # noqa: C901, WPS210
194197
# TaskiqDepends.
195198
if not isinstance(default_value, Dependency):
196199
continue
197-
198200
# If user haven't set the dependency,
199201
# using TaskiqDepends constructor,
200202
# we need to find variable's type hint.

tests/test_graph.py

+26
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,29 @@ def target(val: int = Depends(dep)) -> None:
704704
) as ctx:
705705
kwargs = await ctx.resolve_kwargs()
706706
assert kwargs["val"] == 321
707+
708+
709+
def test_kwargs_caches() -> None:
710+
"""
711+
Test that kwarged caches work.
712+
713+
If user wants to pass kwargs to the dependency
714+
multiple times, we must verify that it works.
715+
716+
And dependency calculated multiple times,
717+
even with caches.
718+
"""
719+
720+
def random_dep(a: int) -> int:
721+
return a
722+
723+
A = Depends(random_dep, kwargs={"a": 1})
724+
B = Depends(random_dep, kwargs={"a": 2})
725+
726+
def target(a: int = A, b: int = B) -> int:
727+
return a + b
728+
729+
graph = DependencyGraph(target=target)
730+
with graph.sync_ctx() as ctx:
731+
kwargs = ctx.resolve_kwargs()
732+
assert target(**kwargs) == 3

0 commit comments

Comments
 (0)