Skip to content

Subset arrays #411

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 24 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
69e236d
made initial backend functions for adapter subsetting, need to still …
eodole Apr 8, 2025
9c0da4c
added subsample functionality, to do would be adding them to testing …
eodole Apr 11, 2025
d57aee4
made the take function and ran the linter
eodole Apr 11, 2025
8d834da
changed name of subsampling function
eodole Apr 22, 2025
6c1d503
changed documentation, to be consistent with external notation, rathe…
eodole Apr 22, 2025
2e83846
small formation change to documentation
eodole Apr 22, 2025
dee4534
changed subsample to have sample size and axis in the constructor
eodole Apr 22, 2025
71dc35a
moved transforms in the adapter.py so they're in alphabetical order l…
eodole Apr 22, 2025
6c34a5d
changed random_subsample to maptransform rather than filter transform
eodole Apr 22, 2025
c3640cb
updated documentation with new naming convention
eodole Apr 22, 2025
f17322f
added arguments of take to the constructor
eodole Apr 22, 2025
5312c5f
added feature to specify a percentage of the data to subsample rather…
eodole Apr 22, 2025
5361c04
changed subsample in adapter.py to allow float as an input for the sa…
eodole Apr 22, 2025
504344b
renamed subsample_array and associated classes/functions to RandomSub…
eodole Apr 22, 2025
4218b70
included TypeError to force users to only subsample one dataset at a …
eodole Apr 22, 2025
7e3911b
ran linter
eodole May 6, 2025
350513f
merge dev
eodole May 6, 2025
415b658
rerun formatter
LarsKue May 6, 2025
37598b0
clean up random subsample transform and docs
LarsKue May 6, 2025
ee28392
clean up take transform and docs
LarsKue May 6, 2025
5bbf44a
nitpick clean-up
LarsKue May 6, 2025
676c19f
skip shape check for subsampled adapter transform inverse
LarsKue May 6, 2025
f261b50
fix serialization of new transforms
LarsKue May 6, 2025
87017e8
skip randomly subsampled key in serialization consistency check
LarsKue May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ docs/

# MacOS
.DS_Store

# Rproj
.Rproj.user
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am unfamiliar with R. What is this directory used for, and should all other users have it ignored too? Otherwise, please put this in your local .git/info/exclude instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to @stefanradev93, this should be .Rproj

62 changes: 61 additions & 1 deletion bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
Standardize,
ToArray,
Transform,
RandomSubsample,
Take,
)
from .transforms.filter_transform import Predicate

Expand Down Expand Up @@ -665,6 +667,28 @@
self.transforms.append(transform)
return self

def random_subsample(self, key: str, *, sample_size: int | float, axis: int = -1):
"""
Append a :py:class:`~transforms.RandomSubsample` transform to the adapter.

Parameters
----------
key : str or Sequence of str
The name of the variable to subsample.
sample_size : int or float
The number of samples to draw, or a fraction between 0 and 1 of the total number of samples to draw.
axis: int, optional
Which axis to draw samples over. The last axis is used by default.
"""

if not isinstance(key, str):
raise TypeError("Can only subsample one batch entry at a time.")

Check warning on line 685 in bayesflow/adapters/adapter.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/adapter.py#L685

Added line #L685 was not covered by tests

transform = MapTransform({key: RandomSubsample(sample_size=sample_size, axis=axis)})

self.transforms.append(transform)
return self

def rename(self, from_key: str, to_key: str):
"""Append a :py:class:`~transforms.Rename` transform to the adapter.

Expand Down Expand Up @@ -741,7 +765,7 @@
Names of variables to include in the transform.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
**kwargs : dict
**kwargs :
Additional keyword arguments passed to the transform.
"""
transform = FilterTransform(
Expand All @@ -754,6 +778,42 @@
self.transforms.append(transform)
return self

def take(
self,
include: str | Sequence[str] = None,
*,
indices: Sequence[int],
axis: int = -1,
predicate: Predicate = None,
exclude: str | Sequence[str] = None,
):
"""
Append a :py:class:`~transforms.Take` transform to the adapter.

Parameters
----------
include : str or Sequence of str, optional
Names of variables to include in the transform.
indices : Sequence of int
Which indices to take from the data.
axis : int, optional
Which axis to take from. The last axis is used by default.
predicate : Predicate, optional
Function that indicates which variables should be transformed.
exclude : str or Sequence of str, optional
Names of variables to exclude from the transform.
"""
transform = FilterTransform(
transform_constructor=Take,
predicate=predicate,
include=include,
exclude=exclude,
indices=indices,
axis=axis,
)
self.transforms.append(transform)
return self

def to_array(
self,
include: str | Sequence[str] = None,
Expand Down
2 changes: 2 additions & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from .to_array import ToArray
from .to_dict import ToDict
from .transform import Transform
from .random_subsample import RandomSubsample
from .take import Take

from ...utils._docs import _add_imports_to_all

Expand Down
48 changes: 48 additions & 0 deletions bayesflow/adapters/transforms/random_subsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
from bayesflow.utils.serialization import serializable, serialize
from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class RandomSubsample(ElementwiseTransform):
"""
A transform that takes a random subsample of the data within an axis.

Example: adapter.random_subsample("x", sample_size = 3, axis = -1)

"""

def __init__(
self,
sample_size: int | float,
axis: int = -1,
):
super().__init__()
if isinstance(sample_size, float):
if sample_size <= 0 or sample_size >= 1:
ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")

Check warning on line 23 in bayesflow/adapters/transforms/random_subsample.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/random_subsample.py#L22-L23

Added lines #L22 - L23 were not covered by tests
self.sample_size = sample_size
self.axis = axis

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
axis = self.axis
max_sample_size = data.shape[axis]

if isinstance(self.sample_size, int):
sample_size = self.sample_size
else:
sample_size = np.round(self.sample_size * max_sample_size)

Check warning on line 34 in bayesflow/adapters/transforms/random_subsample.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/random_subsample.py#L34

Added line #L34 was not covered by tests

# random sample without replacement
sample_indices = np.random.permutation(max_sample_size)[0 : sample_size - 1]

return np.take(data, sample_indices, axis)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
# non invertible transform
return data

def get_config(self) -> dict:
config = {"sample_size": self.sample_size, "axis": self.axis}

return serialize(config)
31 changes: 31 additions & 0 deletions bayesflow/adapters/transforms/take.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from collections.abc import Sequence
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable(package="bayesflow.adapters")
class Take(ElementwiseTransform):
"""
A transform to reduce the dimensionality of arrays output by the summary network
Example: adapter.take("x", np.arange(0,3), axis=-1)
"""

def __init__(self, indices: Sequence[int], axis: int = -1):
super().__init__()
self.indices = indices
self.axis = axis

def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
return np.take(data, self.indices, self.axis)

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
# not a true invertible function
return data

def get_config(self) -> dict:
config = {"indices": self.indices, "axis": self.axis}

return serialize(config)
17 changes: 7 additions & 10 deletions tests/test_adapters/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def adapter():
def serializable_fn(x):
return x

d = (
return (
Adapter()
.to_array()
.as_set(["s1", "s2"])
Expand All @@ -32,12 +32,12 @@ def serializable_fn(x):
.standardize(exclude=["t1", "t2", "o1"])
.drop("d1")
.one_hot("o1", 10)
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "t1", "t2", "o1", "split_1", "split_2"])
.keep(["x", "y", "z1", "p1", "p2", "s1", "s2", "s3", "t1", "t2", "o1", "split_1", "split_2"])
.rename("o1", "o2")
.random_subsample("s3", sample_size=33, axis=0)
.take("s3", indices=np.arange(0, 32), axis=0)
)

return d


@pytest.fixture()
def random_data():
Expand All @@ -58,6 +58,7 @@ def random_data():
"d1": np.random.standard_normal(size=(32, 2)),
"d2": np.random.standard_normal(size=(32, 2)),
"o1": np.random.randint(0, 9, size=(32, 2)),
"s3": np.random.standard_normal(size=(35, 2)),
"u1": np.random.uniform(low=-1, high=2, size=(32, 1)),
"key_to_split": np.random.standard_normal(size=(32, 10)),
}
Expand All @@ -67,7 +68,7 @@ def random_data():
def adapter_log_det_jac():
from bayesflow.adapters import Adapter

adapter = (
return (
Adapter()
.scale("x1", by=2)
.log("p1", p1=True)
Expand All @@ -79,14 +80,12 @@ def adapter_log_det_jac():
.rename("u1", "u")
)

return adapter


@pytest.fixture()
def adapter_log_det_jac_inverse():
from bayesflow.adapters import Adapter

adapter = (
return (
Adapter()
.standardize("x1", mean=1, std=2)
.log("p1")
Expand All @@ -96,5 +95,3 @@ def adapter_log_det_jac_inverse():
.constrain("u1", lower=-1, upper=2)
.scale(["p1", "p2", "p3"], by=3.5)
)

return adapter
7 changes: 7 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def test_cycle_consistency(adapter, random_data):
if key in ["d1", "d2", "p3", "n1", "u1"]:
# dropped
continue
if key == "s3":
# we subsampled this key, so it is expected for its shape to change
continue
assert key in deprocessed
assert np.allclose(value, deprocessed[key])

Expand All @@ -31,6 +34,10 @@ def test_serialize_deserialize(adapter, random_data):
random_data["foo"] = random_data["x1"]
deserialized_processed = deserialized(random_data)
for key, value in processed.items():
if key == "s3":
# skip this key because it is *randomly* subsampled
continue

assert np.allclose(value, deserialized_processed[key])


Expand Down