Skip to content

Commit c169912

Browse files
authored
Add Pandas to_dict adapter transform (#416)
* add to_dict transform * add to_dict dispatch on adapter * add to_dict tests * copilot nitpick * update serialization protocol
1 parent 147dd2d commit c169912

File tree

6 files changed

+97
-18
lines changed

6 files changed

+97
-18
lines changed

bayesflow/adapters/adapter.py

+7
Original file line numberDiff line numberDiff line change
@@ -755,3 +755,10 @@ def to_array(
755755
)
756756
self.transforms.append(transform)
757757
return self
758+
759+
def to_dict(self):
    """Append a :py:class:`~transforms.ToDict` transform to the adapter.

    ``ToDict`` converts non-dict batches (e.g., pandas.DataFrame) to dict batches.

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    # imported locally, in the same way the original method did
    from .transforms import ToDict

    self.transforms.append(ToDict())
    return self

bayesflow/adapters/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .sqrt import Sqrt
2222
from .standardize import Standardize
2323
from .to_array import ToArray
24+
from .to_dict import ToDict
2425
from .transform import Transform
2526

2627
from ...utils._docs import _add_imports_to_all

bayesflow/adapters/transforms/split.py

+9-17
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from collections.abc import Sequence
22
import numpy as np
33

4-
from keras.saving import (
5-
deserialize_keras_object as deserialize,
6-
register_keras_serializable as serializable,
7-
serialize_keras_object as serialize,
8-
)
4+
from bayesflow.utils.serialization import deserialize, serializable, serialize
95

106
from .transform import Transform
117

128

13-
@serializable(package="bayesflow.adapters")
9+
@serializable
1410
class Split(Transform):
1511
"""This is the effective inverse of the :py:class:`~Concatenate` Transform.
1612
@@ -38,20 +34,16 @@ def __init__(self, key: str, into: Sequence[str], indices_or_sections: int | Seq
3834

3935
@classmethod
4036
def from_config(cls, config: dict, custom_objects=None) -> "Split":
41-
return cls(
42-
key=deserialize(config["key"], custom_objects),
43-
into=deserialize(config["into"], custom_objects),
44-
indices_or_sections=deserialize(config["indices_or_sections"], custom_objects),
45-
axis=deserialize(config["axis"], custom_objects),
46-
)
37+
return cls(**deserialize(config, custom_objects=custom_objects))
4738

4839
def get_config(self) -> dict:
49-
return {
50-
"key": serialize(self.key),
51-
"into": serialize(self.into),
52-
"indices_or_sections": serialize(self.indices_or_sections),
53-
"axis": serialize(self.axis),
40+
config = {
41+
"key": self.key,
42+
"into": self.into,
43+
"indices_or_sections": self.indices_or_sections,
44+
"axis": self.axis,
5445
}
46+
return serialize(config)
5547

5648
def forward(self, data: dict[str, np.ndarray], strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
5749
# avoid side effects

bayesflow/adapters/transforms/to_array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .elementwise_transform import ElementwiseTransform
88

99

10-
@serializable(package="bayesflow.adapters")
10+
@serializable
1111
class ToArray(ElementwiseTransform):
1212
"""
1313
Checks provided data for any non-arrays and converts them to numpy arrays.
bayesflow/adapters/transforms/to_dict.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from bayesflow.utils.serialization import serializable
5+
6+
from .transform import Transform
7+
8+
9+
@serializable
class ToDict(Transform):
    """Convert non-dict batches (e.g., pandas.DataFrame) to dict batches.

    Every value of the resulting dict is converted to a ``float32`` numpy array.
    ``pandas.Series`` values with object, string, or categorical dtype are
    one-hot encoded with ``pandas.get_dummies`` before the conversion.
    """

    @classmethod
    def from_config(cls, config: dict, custom_objects=None):
        """Rebuild the transform. It has no parameters, so ``config`` is not read."""
        return cls()

    def get_config(self) -> dict:
        """Return the (empty) configuration of this parameter-free transform."""
        return {}

    def forward(self, data, **kwargs) -> dict[str, np.ndarray]:
        """Turn ``data`` into a dict of ``float32`` arrays.

        Parameters
        ----------
        data :
            Anything accepted by ``dict(...)``, e.g. a ``pandas.DataFrame``
            (yielding one ``pandas.Series`` per column) or an existing dict.
        **kwargs :
            Accepted for interface compatibility; not used here.

        Returns
        -------
        A new dict with the same keys whose values are ``float32`` numpy arrays.
        """
        # dict(...) builds a new mapping, so the caller's container is not modified
        data = dict(data)

        # only existing keys are re-assigned below, so iterating over items() is safe
        for key, value in data.items():
            if isinstance(value, pd.Series):
                dtype = value.dtype

                # Text columns may carry the legacy object dtype or a dedicated
                # pandas string dtype; both must become categorical, otherwise
                # the float cast below fails on non-numeric strings.
                if pd.api.types.is_object_dtype(dtype) or pd.api.types.is_string_dtype(dtype):
                    value = value.astype("category")

                if isinstance(value.dtype, pd.CategoricalDtype):
                    # one-hot encode; the result has one column per category
                    value = pd.get_dummies(value)

            value = np.asarray(value).astype("float32", copy=False)

            data[key] = value

        return data

    def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray]:
        """Return ``data`` unchanged."""
        # non-invertible transform
        return data

tests/test_adapters/test_adapters.py

+40
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from bayesflow.utils.serialization import deserialize, serialize
77

8+
import bayesflow as bf
9+
810

911
def test_cycle_consistency(adapter, random_data):
1012
processed = adapter(random_data)
@@ -190,3 +192,41 @@ def test_split_transform(adapter, random_data):
190192

191193
assert "split_2" in processed
192194
assert processed["split_2"].shape == target_shape
195+
196+
197+
def test_to_dict_transform():
    """``Adapter.to_dict`` should turn a DataFrame batch into a dict of float32 arrays."""
    import pandas as pd

    integers = [1, 2, 3, 4, 5]
    floats = [1.0, 2.0, 3.0, 4.0, 5.0]

    # every column is named after the dtype it is cast to below
    columns = {
        "int32": integers,
        "int64": integers,
        "float32": floats,
        "float64": floats,
        "object": ["a", "b", "c", "d", "e"],
        "category": ["one", "two", "three", "four", "five"],
    }
    expected_keys = list(columns)

    df = pd.DataFrame(columns)
    for name in expected_keys:
        df[name] = df[name].astype(name)

    # drop one element to simulate non-complete data
    batch = df.iloc[:-1]

    processed = bf.Adapter().to_dict()(batch)

    assert isinstance(processed, dict)
    assert list(processed.keys()) == expected_keys

    for array in processed.values():
        assert isinstance(array, np.ndarray)
        assert array.dtype == "float32"

    # category should have 5 one-hot categories, even though it was only passed 4
    assert processed["category"].shape[-1] == 5

0 commit comments

Comments
 (0)