diff --git a/README.md b/README.md index 2abe9db..7b52334 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,21 @@ combined = await FetchWithParams().both( ) print(combined) ``` +### Parameters with default values are ignored + +You can opt a parameter out of the dependency injection mechanism by assigning it a default value: + +```python +class IgnoreDefaultParameters(AsyncInjectAll): + async def go(self, calc1, x=5): + return calc1 + x + + async def calc1(self): + return 5 + +print(await IgnoreDefaultParameters().go()) +# Prints 10 +``` ### AsyncInject and @inject diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py index e3fe841..3cac54e 100644 --- a/asyncinject/__init__.py +++ b/asyncinject/__init__.py @@ -48,12 +48,21 @@ def __new__(cls, name, bases, attrs): def _make_method(method): - parameters = inspect.signature(method).parameters.keys() + parameters = inspect.signature(method).parameters @wraps(method) async def inner(self, **kwargs): # Any parameters not provided by kwargs are resolved from registry - to_resolve = [p for p in parameters if p not in kwargs and p != "self"] + to_resolve = [ + p + for p in parameters + # Not already provided + if p not in kwargs + # Not self + and p != "self" + # Doesn't have a default value + and parameters[p].default is inspect._empty + ] missing = [p for p in to_resolve if p not in self._registry] assert ( not missing diff --git a/tests/test_asyncinject.py b/tests/test_asyncinject.py index e64af99..8c2740e 100644 --- a/tests/test_asyncinject.py +++ b/tests/test_asyncinject.py @@ -66,6 +66,14 @@ async def calc2(self, param1): return 6 + param1 +class IgnoreDefaultParameters(AsyncInjectAll): + async def go(self, calc1, x=5): + return calc1 + x + + async def calc1(self): + return 5 + + @pytest.mark.asyncio async def test_simple(): assert await Simple().one() == ["two", "one"] @@ -103,6 +111,12 @@ async def test_parameters_passed_through(): assert result == 12 +@pytest.mark.asyncio +async def test_ignore_default_parameters(): + result = await IgnoreDefaultParameters().go() + assert result == 10 + + @pytest.mark.asyncio async def test_resolve(): object = WithParameters()