-
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handle non-async functions too #15
Comments
Relevant code: asyncinject/asyncinject/__init__.py Lines 96 to 121 in ddb71fd
I need to figure out how to have that |
I probably need something like this:

import asyncio
async def await_me_maybe(value):
    """Resolve *value* to a plain result, calling and/or awaiting it as needed.

    If *value* is callable it is first called with no arguments. If the
    resulting object (or *value* itself, when not callable) is a coroutine,
    it is awaited. Any other value is returned unchanged.

    Args:
        value: A plain value, a zero-argument callable, a coroutine, or a
            zero-argument callable that returns a coroutine.

    Returns:
        The final, non-coroutine result.
    """
    if callable(value):
        value = value()
    # Only coroutine objects are awaited; other awaitables (e.g. Futures)
    # are returned as-is, matching the asyncio.iscoroutine check below.
    if asyncio.iscoroutine(value):
        value = await value
    return value
This seems to work - it passes the tests:

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..3a92f41 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -27,7 +27,7 @@ class Registry:
def _make_time_logger(self, awaitable):
async def inner():
start = time.perf_counter()
- result = await awaitable
+ result = await await_me_maybe(awaitable)
end = time.perf_counter()
self.timer(awaitable.__name__, start, end)
return result
@@ -90,8 +90,9 @@ class Registry:
**{k: v for k, v in results.items() if k in self.graph[name]},
)
if self.timer:
- aw = self._make_time_logger(aw)
- return aw
+ return self._make_time_logger(aw)
+ else:
+ return await_me_maybe(aw)
async def _execute_sequential(self, results, ts):
for name in ts.static_order():
@@ -132,3 +133,11 @@ class Registry:
await self._execute_sequential(results, ts)
return results
+
+
+async def await_me_maybe(value):
+ if callable(value):
+ value = value()
+ if asyncio.iscoroutine(value):
+ value = await value
+ return value

And in the Python shell (with top-level await enabled):

>>> import asyncio
>>> from asyncinject import Registry
>>> one = lambda: 1
>>> two = lambda: 2
>>> three = lambda one, two: one + two
>>>
>>> three
<function <lambda> at 0x107d19870>
>>> three.__name__
'<lambda>'
>>> r = Registry()
>>> r.register(one, name='one')
>>> r.register(two, name='two')
>>> await r.resolve(three)
3 |
Need to write tests that show how that operates when mixed with async functions, in particular the planning and timing stuff. |
Ran into a problem with the intersection of this and the logging mechanism. I changed the diff to this:

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..1e3bd1f 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -90,8 +90,9 @@ class Registry:
**{k: v for k, v in results.items() if k in self.graph[name]},
)
if self.timer:
- aw = self._make_time_logger(aw)
- return aw
+ return self._make_time_logger(await_me_maybe(aw))
+ else:
+ return await_me_maybe(aw)
async def _execute_sequential(self, results, ts):
for name in ts.static_order():
@@ -132,3 +133,11 @@ class Registry:
await self._execute_sequential(results, ts)
return results
+
+
+async def await_me_maybe(value):
+ if callable(value):
+ value = value()
+ if asyncio.iscoroutine(value):
+ value = await value
+ return value
diff --git a/tests/test_asyncinject.py b/tests/test_asyncinject.py
index fdd3318..ea11603 100644
--- a/tests/test_asyncinject.py
+++ b/tests/test_asyncinject.py
@@ -187,7 +187,8 @@ async def test_resolve_unregistered_function(use_async):
async def test_register():
registry = Registry()
- async def one():
+ # Mix in a non-async function too:
+ def one():
return "one"
async def two_():
@@ -207,3 +208,26 @@ async def test_register():
result = await registry.resolve(three)
assert result == "onetwo"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("parallel", (True, False))
+async def test_just_sync_functions(parallel):
+ def one():
+ return 1
+
+ def two():
+ return 2
+
+ def three(one, two):
+ return one + two
+
+ timed = []
+
+ registry = Registry(
+ one, two, three, parallel=parallel, timer=lambda *args: timed.append(args)
+ )
+ result = await registry.resolve(three)
+ assert result == 3
+
+ assert False

And got this test failure:
|
Here's the problem: asyncinject/asyncinject/__init__.py Lines 27 to 35 in ddb71fd
I'm trying to introspect the name of the function as part of the logging mechanism. |
Actually I think the core problem is here: asyncinject/asyncinject/__init__.py Lines 88 to 94 in ddb71fd
If |
OK, a better solution:

diff --git a/asyncinject/__init__.py b/asyncinject/__init__.py
index a553660..4a4cac9 100644
--- a/asyncinject/__init__.py
+++ b/asyncinject/__init__.py
@@ -86,9 +86,20 @@ class Registry:
return ts
def _get_awaitable(self, name, results):
- aw = self._registry[name](
- **{k: v for k, v in results.items() if k in self.graph[name]},
- )
+ fn = self._registry[name]
+ kwargs = {k: v for k, v in results.items() if k in self.graph[name]}
+
+ awaitable_fn = fn
+
+ if not asyncio.iscoroutinefunction(fn):
+
+ async def _awaitable(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ _awaitable.__name__ = fn.__name__
+ awaitable_fn = _awaitable
+
+ aw = awaitable_fn(**kwargs)
if self.timer:
aw = self._make_time_logger(aw)
return aw |
It may be useful if this could also handle regular `def ...` functions, in addition to `async def` functions.

I'm imagining using this for Datasette extras where some extras might be able to operate directly on data that has already been fetched by other functions - e.g. an extra which transforms objects in some way.
Refs:
The text was updated successfully, but these errors were encountered: