Skip to content

Add Pandas to_dict adapter transform #416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,10 @@ def to_array(
)
self.transforms.append(transform)
return self

def to_dict(self):
    """Append a ``ToDict`` transform to this adapter's list of transforms.

    Returns the adapter itself so that further calls can be chained.
    """
    # imported locally, mirroring the original placement of this import
    from .transforms import ToDict

    self.transforms.append(ToDict())
    return self
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .sqrt import Sqrt
from .standardize import Standardize
from .to_array import ToArray
from .to_dict import ToDict
from .transform import Transform

from ...utils._docs import _add_imports_to_all
Expand Down
26 changes: 9 additions & 17 deletions bayesflow/adapters/transforms/split.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from collections.abc import Sequence
import numpy as np

from keras.saving import (
deserialize_keras_object as deserialize,
register_keras_serializable as serializable,
serialize_keras_object as serialize,
)
from bayesflow.utils.serialization import deserialize, serializable, serialize

from .transform import Transform


@serializable(package="bayesflow.adapters")
@serializable
class Split(Transform):
"""This is the effective inverse of the :py:class:`~Concatenate` Transform.

Expand Down Expand Up @@ -38,20 +34,16 @@ def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Seq

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Split":
return cls(
key=deserialize(config["key"], custom_objects),
into=deserialize(config["into"], custom_objects),
indices_or_sections=deserialize(config["indices_or_sections"], custom_objects),
axis=deserialize(config["axis"], custom_objects),
)
return cls(**deserialize(config, custom_objects=custom_objects))

def get_config(self) -> dict:
return {
"key": serialize(self.key),
"into": serialize(self.into),
"indices_or_sections": serialize(self.indices_or_sections),
"axis": serialize(self.axis),
config = {
"key": self.key,
"into": self.into,
"indices_or_sections": self.indices_or_sections,
"axis": self.axis,
}
return serialize(config)

def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
# avoid side effects
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/adapters/transforms/to_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
@serializable
class ToArray(ElementwiseTransform):
"""
Checks provided data for any non-arrays and converts them to numpy arrays.
Expand Down
39 changes: 39 additions & 0 deletions bayesflow/adapters/transforms/to_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import pandas as pd

from bayesflow.utils.serialization import serializable

from .transform import Transform


@serializable
class ToDict(Transform):
"""Convert non-dict batches (e.g., pandas.DataFrame) to dict batches"""

@classmethod
def from_config(cls, config: dict, custom_objects=None):
    """Build an instance from a serialized config.

    Both ``config`` and ``custom_objects`` are ignored: the constructor is
    called without arguments.
    """
    instance = cls()
    return instance

Check warning on line 15 in bayesflow/adapters/transforms/to_dict.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_dict.py#L15

Added line #L15 was not covered by tests

def get_config(self) -> dict:
    """Return the serializable config, which is always an empty dict."""
    config = dict()
    return config

Check warning on line 18 in bayesflow/adapters/transforms/to_dict.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_dict.py#L18

Added line #L18 was not covered by tests

def forward(self, data, **kwargs) -> dict[str, np.ndarray]:
    """Convert a mapping-like batch into a dict of float32 numpy arrays.

    ``data`` is passed to ``dict()``, so anything ``dict`` accepts works; for a
    ``pandas.DataFrame`` this yields one ``pandas.Series`` per column.

    ``pandas.Series`` values holding text (``object`` dtype or a pandas string
    dtype) are converted to categoricals, and categorical Series are one-hot
    encoded with ``pandas.get_dummies``. Every value is finally converted with
    ``np.asarray`` and cast to ``float32``.

    Extra keyword arguments are accepted and ignored.
    """
    converted = {}

    for key, value in dict(data).items():
        if isinstance(value, pd.Series):
            # Text may arrive either as ``object`` or as pandas' dedicated
            # string dtype; both must be one-hot encoded, because casting
            # raw text to float32 below would fail.
            if value.dtype == "object" or isinstance(value.dtype, pd.StringDtype):
                value = value.astype("category")

            if value.dtype == "category":
                # one column per category of the dtype, including categories
                # that do not occur in this particular batch
                value = pd.get_dummies(value)

        converted[key] = np.asarray(value).astype("float32", copy=False)

    return converted

def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
    """Return ``data`` unchanged.

    The forward conversion is not invertible, so the inverse direction is a
    pass-through. Extra keyword arguments are accepted and ignored.
    """
    unchanged = data
    return unchanged

Check warning on line 39 in bayesflow/adapters/transforms/to_dict.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/to_dict.py#L39

Added line #L39 was not covered by tests
40 changes: 40 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from bayesflow.utils.serialization import deserialize, serialize

import bayesflow as bf


def test_cycle_consistency(adapter, random_data):
processed = adapter(random_data)
Expand Down Expand Up @@ -190,3 +192,41 @@ def test_split_transform(adapter, random_data):

assert "split_2" in processed
assert processed["split_2"].shape == target_shape


def test_to_dict_transform():
    """Check that ``Adapter.to_dict`` maps a DataFrame batch to float32 arrays keyed by column."""
    import pandas as pd

    # every column is named after the dtype it is cast to below
    columns = {
        "int32": [1, 2, 3, 4, 5],
        "int64": [1, 2, 3, 4, 5],
        "float32": [1.0, 2.0, 3.0, 4.0, 5.0],
        "float64": [1.0, 2.0, 3.0, 4.0, 5.0],
        "object": ["a", "b", "c", "d", "e"],
        "category": ["one", "two", "three", "four", "five"],
    }

    frame = pd.DataFrame(columns)
    for dtype_name in columns:
        frame[dtype_name] = frame[dtype_name].astype(dtype_name)

    to_dict_adapter = bf.Adapter().to_dict()

    # drop one element to simulate non-complete data
    processed = to_dict_adapter(frame.iloc[:-1])

    assert isinstance(processed, dict)
    assert list(processed.keys()) == ["int32", "int64", "float32", "float64", "object", "category"]

    for array in processed.values():
        assert isinstance(array, np.ndarray)
        assert array.dtype == "float32"

    # category should have 5 one-hot categories, even though it was only passed 4
    assert processed["category"].shape[-1] == 5