1
1
import asyncio
2
2
import inspect
3
+ from collections import defaultdict
3
4
from copy import copy
4
5
from logging import getLogger
5
- from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional
6
+ from typing import TYPE_CHECKING , Any , DefaultDict , Dict , Generator , List , Optional
6
7
7
8
from taskiq_dependencies .utils import ParamInfo
8
9
@@ -49,6 +50,8 @@ def traverse_deps( # noqa: C901, WPS210
49
50
# to separate dependencies that use cache,
50
51
# from dependencies that aren't.
51
52
cache = copy (self .initial_cache )
53
+ # Cache for all dependencies with kwargs.
54
+ kwargs_cache : "DefaultDict[Any, List[Any]]" = defaultdict (list )
52
55
# We iterate over topologicaly sorted list of dependencies.
53
56
for index , dep in enumerate (self .graph .ordered_deps ):
54
57
# If this dependency doesn't use cache,
@@ -62,6 +65,19 @@ def traverse_deps( # noqa: C901, WPS210
62
65
# If dependency is already calculated.
63
66
if dep .dependency in cache :
64
67
continue
68
+ # For dependencies with kwargs we check kwarged cache.
69
+ elif dep .kwargs and dep .dependency in kwargs_cache :
70
+ cache_hit = False
71
+ # We have to iterate over all cached dependencies with
72
+ # kwargs, because users may pass unhashable objects as kwargs.
73
+ # That's why we cannot use them as dict keys.
74
+ for cached_kwargs , _ in kwargs_cache [dep .dependency ]:
75
+ if cached_kwargs == dep .kwargs :
76
+ cache_hit = True
77
+ break
78
+ if cache_hit :
79
+ continue
80
+
65
81
kwargs = {}
66
82
# Now we get list of dependencies for current top-level dependency
67
83
# and iterate over it.
@@ -78,7 +94,13 @@ def traverse_deps( # noqa: C901, WPS210
78
94
if subdep .use_cache :
79
95
# If this dependency can be calculated, using cache,
80
96
# we try to get it from cache.
81
- kwargs [subdep .param_name ] = cache [subdep .dependency ]
97
+ if subdep .kwargs and subdep .dependency in kwargs_cache :
98
+ for cached_kwargs , kw_cache in kwargs_cache [subdep .dependency ]:
99
+ if cached_kwargs == subdep .kwargs :
100
+ kwargs [subdep .param_name ] = kw_cache
101
+ break
102
+ else :
103
+ kwargs [subdep .param_name ] = cache [subdep .dependency ]
82
104
else :
83
105
# If this dependency doesn't use cache,
84
106
# we resolve it's dependencies and
@@ -101,9 +123,13 @@ def traverse_deps( # noqa: C901, WPS210
101
123
# because we calculate them when needed.
102
124
and dep .dependency != ParamInfo
103
125
):
104
- user_kwargs = dep .kwargs
126
+ user_kwargs = copy ( dep .kwargs )
105
127
user_kwargs .update (kwargs )
106
- cache [dep .dependency ] = yield dep .dependency (** user_kwargs )
128
+ resolved = yield dep .dependency (** user_kwargs )
129
+ if dep .kwargs :
130
+ kwargs_cache [dep .dependency ].append ((dep .kwargs , resolved ))
131
+ else :
132
+ cache [dep .dependency ] = resolved
107
133
return kwargs
108
134
109
135
0 commit comments