def to_dict(self):
    """Append a ``ToDict`` transform to this adapter's list of transforms.

    Returns the adapter itself so that further calls can be chained.
    """
    # imported at call time rather than at module level, as in the original
    from .transforms import ToDict

    self.transforms.append(ToDict())
    return self
@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Split":
    """Build a ``Split`` by deserializing ``config`` and passing the result as keyword arguments."""
    kwargs = deserialize(config, custom_objects=custom_objects)
    return cls(**kwargs)

def get_config(self) -> dict:
    """Return the serialized ``key``, ``into``, ``indices_or_sections`` and ``axis`` of this transform."""
    return serialize(
        {
            "key": self.key,
            "into": self.into,
            "indices_or_sections": self.indices_or_sections,
            "axis": self.axis,
        }
    )
@serializable
class ToDict(Transform):
    """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches.

    The batch is passed through ``dict(...)``. Every value that is a ``pandas.Series`` is then
    converted to a ``float32`` numpy array:

    - text-like columns (``object`` dtype or a pandas string dtype) are first cast to ``category``;
    - categorical columns are one-hot encoded with ``pandas.get_dummies``;
    - everything else is cast to ``float32`` directly.

    Values that are not a ``pandas.Series`` are passed through unchanged.

    NOTE(review): text columns are cast to ``category`` on every call, so the width of their
    one-hot encoding presumably follows the distinct values present in that batch. Columns that
    already carry a categorical dtype are expected to keep their declared categories - confirm
    this is the intended contract for incomplete batches.
    """

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        """Create a new instance; the transform has no configuration, so ``config`` is ignored."""
        return cls()

    def get_config(self) -> dict:
        """Return an empty config, since the transform holds no state."""
        return {}

    def forward(self, data, **kwargs) -> dict[str, np.ndarray]:
        """Turn ``data`` into a dict and convert its ``pandas.Series`` values to ``float32`` arrays.

        Parameters
        ----------
        data
            Anything accepted by ``dict(...)``, e.g. a ``pandas.DataFrame`` or a mapping.
        **kwargs
            Accepted for interface compatibility and ignored.

        Returns
        -------
        A new dict with the same keys in the same order as ``dict(data)``.
        """
        # build a fresh dict instead of reassigning entries while iterating over them
        return {key: self._to_float_array(value) for key, value in dict(data).items()}

    @staticmethod
    def _to_float_array(value):
        """Convert one ``pandas.Series`` to a ``float32`` array; return any other value unchanged."""
        if not isinstance(value, pd.Series):
            return value

        # Treat pandas string extension dtypes like plain object columns. Comparing the dtype
        # against "object" alone misses them, and the float cast below would then fail on text.
        if pd.api.types.is_object_dtype(value.dtype) or pd.api.types.is_string_dtype(value.dtype):
            value = value.astype("category")

        if isinstance(value.dtype, pd.CategoricalDtype):
            # one-hot encode: one column per category
            value = pd.get_dummies(value)

        return np.asarray(value).astype("float32", copy=False)

    def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """Return ``data`` unchanged."""
        # non-invertible transform
        return data
def test_to_dict_transform():
    """Check that ``Adapter().to_dict()`` turns a DataFrame batch into a dict of float32 arrays."""
    import pandas as pd

    columns = {
        "int32": [1, 2, 3, 4, 5],
        "int64": [1, 2, 3, 4, 5],
        "float32": [1.0, 2.0, 3.0, 4.0, 5.0],
        "float64": [1.0, 2.0, 3.0, 4.0, 5.0],
        "object": ["a", "b", "c", "d", "e"],
        "category": ["one", "two", "three", "four", "five"],
    }

    df = pd.DataFrame(columns)
    # every column is named after the dtype it is cast to
    for dtype_name in columns:
        df[dtype_name] = df[dtype_name].astype(dtype_name)

    adapter = bf.Adapter().to_dict()

    # drop one element to simulate non-complete data
    processed = adapter(df.iloc[:-1])

    assert isinstance(processed, dict)
    assert list(processed) == ["int32", "int64", "float32", "float64", "object", "category"]

    for array in processed.values():
        assert isinstance(array, np.ndarray)
        assert array.dtype == "float32"

    # category should have 5 one-hot categories, even though it was only passed 4
    assert processed["category"].shape[-1] == 5