Skip to content

Commit 1e63803

Browse files
authored
Drastically Improve Speed of Import (#435)
* optimize serializable decorator by getting the parent package name with sys._getframe instead of inspect.getmodule * optimize _add_imports_to_all * fix numpy-keras interop in tests
1 parent d3b3cbc commit 1e63803

File tree

3 files changed

+37
-27
lines changed

3 files changed

+37
-27
lines changed
+24-16
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
1-
import inspect
1+
import sys
2+
import types
23

34

45
def _add_imports_to_all(include_modules: bool | list[str] = False, exclude: list[str] | None = None):
56
"""Add all global variables to __all__"""
67
if not isinstance(include_modules, (bool, list)):
78
raise ValueError("include_modules must be a boolean or a list of strings")
89

9-
exclude = exclude or []
10-
calling_module = inspect.stack()[1]
11-
local_stack = calling_module[0]
12-
global_vars = local_stack.f_globals
13-
all_vars = global_vars["__all__"] if "__all__" in global_vars else []
14-
included_vars = []
15-
for var_name in set(global_vars.keys()):
16-
if inspect.ismodule(global_vars[var_name]):
17-
if include_modules is True and var_name not in exclude and not var_name.startswith("_"):
18-
included_vars.append(var_name)
19-
elif isinstance(include_modules, list) and var_name in include_modules:
20-
included_vars.append(var_name)
21-
elif var_name not in exclude and not var_name.startswith("_"):
22-
included_vars.append(var_name)
23-
global_vars["__all__"] = sorted(list(set(all_vars).union(included_vars)))
10+
exclude_set = set(exclude or [])
11+
contains = exclude_set.__contains__
12+
mod_type = types.ModuleType
13+
frame = sys._getframe(1)
14+
g: dict = frame.f_globals
15+
existing = set(g.get("__all__", []))
16+
17+
to_add = []
18+
include_list = include_modules if isinstance(include_modules, list) else ()
19+
inc_all = include_modules is True
20+
21+
for name, val in g.items():
22+
if name.startswith("_") or contains(name):
23+
continue
24+
25+
if isinstance(val, mod_type):
26+
if inc_all or name in include_list:
27+
to_add.append(name)
28+
else:
29+
to_add.append(name)
30+
31+
g["__all__"] = sorted(existing.union(to_add))

bayesflow/utils/serialization.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import keras
66
import numpy as np
7+
import sys
78

89
# this import needs to be exactly like this to work with monkey patching
910
from keras.saving import deserialize_keras_object
@@ -97,7 +98,10 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
9798
# we marked this as a type during serialization
9899
obj = obj[len(_type_prefix) :]
99100
tp = keras.saving.get_registered_object(
100-
obj, custom_objects=custom_objects, module_objects=builtins.__dict__ | np.__dict__
101+
# TODO: can we pass module objects without overwriting numpy's dict with builtins?
102+
obj,
103+
custom_objects=custom_objects,
104+
module_objects=np.__dict__ | builtins.__dict__,
101105
)
102106
if tp is None:
103107
raise ValueError(
@@ -117,10 +121,9 @@ def deserialize(obj, custom_objects=None, safe_mode=True, **kwargs):
117121
@allow_args
118122
def serializable(cls, package=None, name=None):
119123
if package is None:
120-
# get the calling module's name, e.g. "bayesflow.networks.inference_network"
121-
stack = inspect.stack()
122-
module = inspect.getmodule(stack[1][0])
123-
package = copy(module.__name__)
124+
frame = sys._getframe(1)
125+
g = frame.f_globals
126+
package = g.get("__name__", "bayesflow")
124127

125128
if name is None:
126129
name = copy(cls.__name__)

tests/test_utils/test_dispatch.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# Import the dispatch functions
55
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
6+
from tests.utils import assert_allclose
67

78
# --- Tests for find_network.py ---
89

@@ -118,23 +119,21 @@ def test_find_pooling_mean():
118119
# Check that a keras Lambda layer is returned
119120
assert isinstance(pooling, keras.layers.Lambda)
120121
# Test that the lambda function produces a mean when applied to a sample tensor.
121-
import numpy as np
122122

123-
sample = np.array([[1, 2], [3, 4]])
123+
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
124124
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
125125
result = pooling.call(sample)
126-
np.testing.assert_allclose(result, sample.mean(axis=-2))
126+
assert_allclose(result, keras.ops.mean(sample, axis=-2))
127127

128128

129129
@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
130130
def test_find_pooling_max_min(name, func):
131131
pooling = find_pooling(name)
132132
assert isinstance(pooling, keras.layers.Lambda)
133-
import numpy as np
134133

135-
sample = np.array([[1, 2], [3, 4]])
134+
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
136135
result = pooling.call(sample)
137-
np.testing.assert_allclose(result, func(sample, axis=-2))
136+
assert_allclose(result, func(sample, axis=-2))
138137

139138

140139
def test_find_pooling_learnable(monkeypatch):

0 commit comments

Comments
 (0)