Skip to content

Commit 6caa86e

Browse files
authored
Merge pull request #29 from taskiq-python/feature/graph-dep
2 parents e05a6d4 + 02acbd8 commit 6caa86e

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ with graph.sync_ctx() as ctx:
127127
The ParamInfo has the information about name and parameters signature. It's useful if you want to create a dependency that changes based on parameter name, or signature.
128128

129129

130+
Also ParamInfo contains the initial graph that was used.
131+
130132
## Exception propagation
131133

132134
By default if error happens within the context, we send this error to the dependency,

taskiq_dependencies/ctx.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ class BaseResolveContext:
2020
def __init__(
2121
self,
2222
graph: "DependencyGraph",
23+
main_graph: "DependencyGraph",
2324
initial_cache: Optional[Dict[Any, Any]] = None,
2425
exception_propagation: bool = True,
2526
) -> None:
2627
self.graph = graph
28+
# Main graph that contains all the subgraphs.
29+
self.main_graph = main_graph
2730
self.opened_dependencies: List[Any] = []
2831
self.sub_contexts: "List[Any]" = []
2932
self.initial_cache = initial_cache or {}
@@ -89,7 +92,11 @@ def traverse_deps( # noqa: C901
8992
# If the user want to get ParamInfo,
9093
# we get declaration of the current dependency.
9194
if subdep.dependency == ParamInfo:
92-
kwargs[subdep.param_name] = ParamInfo(dep.param_name, dep.signature)
95+
kwargs[subdep.param_name] = ParamInfo(
96+
dep.param_name,
97+
self.main_graph,
98+
dep.signature,
99+
)
93100
continue
94101
if subdep.use_cache:
95102
# If this dependency can be calculated, using cache,
@@ -197,7 +204,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
197204
:return: dict with resolved kwargs.
198205
"""
199206
if getattr(executed_func, "dep_graph", False):
200-
ctx = SyncResolveContext(executed_func, initial_cache)
207+
ctx = SyncResolveContext(executed_func, self.main_graph, initial_cache)
201208
self.sub_contexts.append(ctx)
202209
sub_result = ctx.resolve_kwargs()
203210
elif inspect.isgenerator(executed_func):
@@ -325,7 +332,7 @@ async def resolver(
325332
:return: dict with resolved kwargs.
326333
"""
327334
if getattr(executed_func, "dep_graph", False):
328-
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
335+
ctx = AsyncResolveContext(executed_func, self.main_graph, initial_cache) # type: ignore
329336
self.sub_contexts.append(ctx)
330337
sub_result = await ctx.resolve_kwargs()
331338
elif inspect.isgenerator(executed_func):

taskiq_dependencies/graph.py

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from taskiq_dependencies.ctx import AsyncResolveContext, SyncResolveContext
99
from taskiq_dependencies.dependency import Dependency
10+
from taskiq_dependencies.utils import ParamInfo
1011

1112
try:
1213
from fastapi.params import Depends as FastapiDepends
@@ -63,6 +64,7 @@ def async_ctx(
6364
if replaced_deps:
6465
graph = DependencyGraph(self.target, replaced_deps)
6566
return AsyncResolveContext(
67+
graph,
6668
graph,
6769
initial_cache,
6870
exception_propagation,
@@ -89,6 +91,7 @@ def sync_ctx(
8991
if replaced_deps:
9092
graph = DependencyGraph(self.target, replaced_deps)
9193
return SyncResolveContext(
94+
graph,
9295
graph,
9396
initial_cache,
9497
exception_propagation,
@@ -122,8 +125,14 @@ def _build_graph(self) -> None: # noqa: C901
122125
continue
123126
if dep.dependency is None:
124127
continue
128+
# If we have replaced dependencies, we need to replace
129+
# them in the current dependency.
125130
if self.replaced_deps and dep.dependency in self.replaced_deps:
126131
dep.dependency = self.replaced_deps[dep.dependency]
132+
# We can say for sure that ParamInfo doesn't have any dependencies,
133+
# so we skip it.
134+
if dep.dependency == ParamInfo:
135+
continue
127136
# Get signature and type hints.
128137
origin = getattr(dep.dependency, "__origin__", None)
129138
if origin is None:

taskiq_dependencies/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import inspect
22
import sys
33
from contextlib import _AsyncGeneratorContextManager, _GeneratorContextManager
4-
from typing import Any, AsyncContextManager, ContextManager, Optional
4+
from typing import TYPE_CHECKING, Any, AsyncContextManager, ContextManager, Optional
55

66
if sys.version_info >= (3, 10):
77
from typing import TypeGuard
88
else:
99
from typing_extensions import TypeGuard
1010

11+
if TYPE_CHECKING:
12+
from taskiq_dependencies.graph import DependencyGraph
13+
1114

1215
class ParamInfo:
1316
"""
@@ -23,9 +26,11 @@ class ParamInfo:
2326
def __init__(
2427
self,
2528
name: str,
29+
graph: "DependencyGraph",
2630
signature: Optional[inspect.Parameter] = None,
2731
) -> None:
2832
self.name = name
33+
self.graph = graph
2934
self.definition = signature
3035

3136
def __repr__(self) -> str:

tests/test_graph.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,15 @@ def dep(info: ParamInfo = Depends()) -> ParamInfo:
334334
def target(my_test_param: ParamInfo = Depends(dep)) -> None:
335335
return None
336336

337-
with DependencyGraph(target=target).sync_ctx() as g:
337+
graph = DependencyGraph(target=target)
338+
with graph.sync_ctx() as g:
338339
kwargs = g.resolve_kwargs()
339340

340341
info: ParamInfo = kwargs["my_test_param"]
341342
assert info.name == "my_test_param"
342343
assert info.definition
343344
assert info.definition.annotation == ParamInfo
345+
assert info.graph == graph
344346

345347

346348
def test_param_info_no_dependant() -> None:
@@ -349,12 +351,14 @@ def test_param_info_no_dependant() -> None:
349351
def target(info: ParamInfo = Depends()) -> None:
350352
return None
351353

352-
with DependencyGraph(target=target).sync_ctx() as g:
354+
graph = DependencyGraph(target=target)
355+
with graph.sync_ctx() as g:
353356
kwargs = g.resolve_kwargs()
354357

355358
info: ParamInfo = kwargs["info"]
356359
assert info.name == ""
357360
assert info.definition is None
361+
assert info.graph == graph
358362

359363

360364
def test_class_based_dependencies() -> None:
@@ -863,3 +867,27 @@ def target(acm: TestACM = Depends(get_test_acm)) -> None:
863867
kwargs = await ctx.resolve_kwargs()
864868
assert kwargs["acm"] == test_acm
865869
assert not test_acm.opened
870+
871+
872+
def test_param_info_subgraph() -> None:
873+
"""
874+
Test subgraphs for ParamInfo.
875+
876+
Test that correct graph is stored in ParamInfo
877+
even if evaluated from subgraphs.
878+
"""
879+
880+
def inner_dep(info: ParamInfo = Depends()) -> ParamInfo:
881+
return info
882+
883+
def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
884+
return None
885+
886+
graph = DependencyGraph(target=target)
887+
with graph.sync_ctx() as g:
888+
kwargs = g.resolve_kwargs()
889+
890+
info: ParamInfo = kwargs["info"]
891+
assert info.name == ""
892+
assert info.definition is None
893+
assert info.graph == graph

0 commit comments

Comments
 (0)