From ddf45c7d1550d45aee2be48977ad3ac84db90c7a Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 15 Apr 2025 19:44:18 -0400 Subject: [PATCH 1/5] add to_dict transform --- bayesflow/adapters/transforms/__init__.py | 1 + bayesflow/adapters/transforms/to_dict.py | 41 +++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 bayesflow/adapters/transforms/to_dict.py diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index afc84e6c3..c089f3319 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -21,6 +21,7 @@ from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray +from .to_dict import ToDict from .transform import Transform from ...utils._docs import _add_imports_to_all diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py new file mode 100644 index 000000000..0fede583d --- /dev/null +++ b/bayesflow/adapters/transforms/to_dict.py @@ -0,0 +1,41 @@ +import numpy as np +import pandas as pd + +from keras.saving import ( + register_keras_serializable as serializable, +) + +from .transform import Transform + + +@serializable(package="bayesflow.adapters") +class ToDict(Transform): + """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches""" + + @classmethod + def from_config(cls, config: dict, custom_objects=None): + return cls() + + def get_config(self) -> dict: + return {} + + def forward(self, data, **kwargs) -> dict[str, np.ndarray]: + data = dict(data) + + for key, value in data.items(): + if isinstance(data[key], pd.Series): + if value.dtype == "object": + value = value.astype("category") + + if value.dtype == "category": + value = pd.get_dummies(value) + + value = np.asarray(value).astype("float32", copy=False) + + data[key] = value + + return data + + def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]: + # non-invertible transform + return data From 49174e0148072a939a5b27e4eb8acb60de9a4fa1 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 15 Apr 2025 19:44:27 -0400 Subject: [PATCH 2/5] add to_dict dispatch on adapter --- bayesflow/adapters/adapter.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 5e3b8aaef..fa4dd2b8b 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -755,3 +755,10 @@ def to_array( ) self.transforms.append(transform) return self + + def to_dict(self): + from .transforms import ToDict + + transform = ToDict() + self.transforms.append(transform) + return self From 21cd4fe66c96268fb0fb54152ede34544f789c1d Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 15 Apr 2025 19:44:31 -0400 Subject: [PATCH 3/5] add to_dict tests --- tests/test_adapters/test_adapters.py | 40 ++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 99b48d003..320f6db93 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -5,6 +5,8 @@ import numpy as np import pytest +import bayesflow as bf + def test_cycle_consistency(adapter, random_data): processed = adapter(random_data) @@ -192,3 +194,41 @@ def test_split_transform(adapter, random_data): assert "split_2" in processed assert processed["split_2"].shape == target_shape + + +def test_to_dict_transform(): + import pandas as pd + + data = { + "int32": [1, 2, 3, 4, 5], + "int64": [1, 2, 3, 4, 5], + "float32": [1.0, 2.0, 3.0, 4.0, 5.0], + "float64": [1.0, 2.0, 3.0, 4.0, 5.0], + "object": ["a", "b", "c", "d", "e"], + "category": ["one", "two", "three", "four", "five"], + } + + df = pd.DataFrame(data) + df["int32"] = df["int32"].astype("int32") + df["int64"] = df["int64"].astype("int64") + df["float32"] = df["float32"].astype("float32") + df["float64"] = df["float64"].astype("float64") + df["object"] = df["object"].astype("object") + df["category"] = df["category"].astype("category") + + ad = bf.Adapter().to_dict() + + # drop one element to simulate non-complete data + batch = df.iloc[:-1] + + processed = ad(batch) + + assert isinstance(processed, dict) + assert list(processed.keys()) == ["int32", "int64", "float32", "float64", "object", "category"] + + for key, value in processed.items(): + assert isinstance(value, np.ndarray) + assert value.dtype == "float32" + + # category should have 5 one-hot categories, even though it was only passed 4 + assert processed["category"].shape[-1] == 5 From 2489c0d062b62e1fe5dacfb42c14de6c84f01d86 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 15 Apr 2025 19:48:58 -0400 Subject: [PATCH 4/5] copilot nitpick --- bayesflow/adapters/transforms/to_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py index 0fede583d..30585b0fe 100644 --- a/bayesflow/adapters/transforms/to_dict.py +++ b/bayesflow/adapters/transforms/to_dict.py @@ -23,7 +23,7 @@ def forward(self, data, **kwargs) -> dict[str, np.ndarray]: data = dict(data) for key, value in data.items(): - if isinstance(data[key], pd.Series): + if isinstance(value, pd.Series): if value.dtype == "object": value = value.astype("category") From 9076458eb5c89c21ee87aad6a6c6941211708368 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Thu, 17 Apr 2025 11:41:15 -0400 Subject: [PATCH 5/5] update serialization protocol --- bayesflow/adapters/transforms/split.py | 26 ++++++++--------------- bayesflow/adapters/transforms/to_array.py | 2 +- bayesflow/adapters/transforms/to_dict.py | 6 ++---- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/bayesflow/adapters/transforms/split.py b/bayesflow/adapters/transforms/split.py index ddbef1a43..919db4e08 100644 --- a/bayesflow/adapters/transforms/split.py +++ b/bayesflow/adapters/transforms/split.py @@ -1,16 +1,12 @@ from collections.abc import Sequence import numpy as np -from keras.saving import ( - deserialize_keras_object as deserialize, - register_keras_serializable as serializable, - serialize_keras_object as serialize, -) +from bayesflow.utils.serialization import deserialize, serializable, serialize from .transform import Transform -@serializable(package="bayesflow.adapters") +@serializable class Split(Transform): """This is the effective inverse of the :py:class:`~Concatenate` Transform. @@ -38,20 +34,16 @@ def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Seq @classmethod def from_config(cls, config: dict, custom_objects=None) -> "Split": - return cls( - key=deserialize(config["key"], custom_objects), - into=deserialize(config["into"], custom_objects), - indices_or_sections=deserialize(config["indices_or_sections"], custom_objects), - axis=deserialize(config["axis"], custom_objects), - ) + return cls(**deserialize(config, custom_objects=custom_objects)) def get_config(self) -> dict: - return { - "key": serialize(self.key), - "into": serialize(self.into), - "indices_or_sections": serialize(self.indices_or_sections), - "axis": serialize(self.axis), + config = { + "key": self.key, + "into": self.into, + "indices_or_sections": self.indices_or_sections, + "axis": self.axis, } + return serialize(config) def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) -> dict[str, np.ndarray]: # avoid side effects diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index 93b1bc86b..9d5381ca0 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -7,7 +7,7 @@ from .elementwise_transform import ElementwiseTransform -@serializable(package="bayesflow.adapters") +@serializable class ToArray(ElementwiseTransform): """ Checks provided data for any non-arrays and converts them to numpy arrays. diff --git a/bayesflow/adapters/transforms/to_dict.py b/bayesflow/adapters/transforms/to_dict.py index 30585b0fe..6babb2a40 100644 --- a/bayesflow/adapters/transforms/to_dict.py +++ b/bayesflow/adapters/transforms/to_dict.py @@ -1,14 +1,12 @@ import numpy as np import pandas as pd -from keras.saving import ( - register_keras_serializable as serializable, -) +from bayesflow.utils.serialization import serializable from .transform import Transform -@serializable(package="bayesflow.adapters") +@serializable class ToDict(Transform): """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches"""