diff --git a/README.md b/README.md index 4c8f80f..7fd9f6a 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,24 @@ The HTTP requests to `www.example.com` and `simonwillison.net` will be performed The library notices that `both()` takes two arguments which are the names of other registered `async def` functions, and will construct an execution plan that executes those two functions in parallel, then passes their results to the `both()` method. +### Registering additional functions + +In addition to registering functions by passing them to the constructor, you can also add them to a registry using the `.register()` method: + +```python +async def another(): + return "another" + +registry.register(another) +``` +To register them with a name other than the name of the function, pass the `name=` argument: +```python +async def another(): + return "another 2" + +registry.register(another, name="another_2") +``` + ### Resolving an unregistered function You don't need to register the final function that you pass to `.resolve()` - if you pass an unregistered function, the library will introspect the function's parameters and resolve them directly. diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py index de69bd5..a553660 100644 --- a/asyncinject/__init__.py +++ b/asyncinject/__init__.py @@ -18,8 +18,8 @@ def __init__(self, *fns, parallel=True, timer=None): for fn in fns: self.register(fn) - def register(self, fn): - self._registry[fn.__name__] = fn + def register(self, fn, *, name=None): + self._registry[name or fn.__name__] = fn # Clear caches: self._graph = None self._reversed = None diff --git a/tests/test_asyncinject.py b/tests/test_asyncinject.py index ad62722..fdd3318 100644 --- a/tests/test_asyncinject.py +++ b/tests/test_asyncinject.py @@ -181,3 +181,29 @@ def three_not_async(one, two): # Test that passing parameters works too result2 = await registry.resolve(fn, one=2) assert result2 == 4 + + +@pytest.mark.asyncio +async def test_register(): + registry = Registry() + + async def one(): + return "one" + + async def two_(): + return "two" + + async def three(one, two): + return one + two + + registry.register(one) + + # Should raise an error if you don't use name= + with pytest.raises(TypeError): + registry.register(two_, "two") + + registry.register(two_, name="two") + + result = await registry.resolve(three) + + assert result == "onetwo"