Skip to content

Commit

Permalink
feat: allow rebinning by passing edges or a new axis (#977)
Browse files Browse the repository at this point in the history
* feat: add rebin by edges/axis, fix flow bins

* chore: linting

* fix: handle flow bins on rebin, add tests

* refactor: move rebin logic out to tag.py

* chore: pass mypy

* refactor: cleanup rebin arg names

* Apply suggestions from code review

* Update tag.py

* style: pre-commit fixes

---------

Co-authored-by: Henry Schreiner <HenrySchreinerIII@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 31, 2025
1 parent a6876bb commit 789d325
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 11 deletions.
24 changes: 18 additions & 6 deletions src/boost_histogram/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
reduced: CppHistogram | None = None

# Compute needed slices and projections
for i, ind in enumerate(indexes):
for i, ind in enumerate(indexes): # pylint: disable=too-many-nested-blocks
if isinstance(ind, SupportsIndex):
pick_each[i] = ind.__index__() + (
1 if self.axes[i].traits.underflow else 0
Expand Down Expand Up @@ -967,14 +967,26 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:

new_reduced = reduced.__class__(axes)
new_view = new_reduced.view(flow=True)

j = 1
j = 0
new_j_base = 0
if self.axes[i].traits.underflow:
groups.insert(0, 1)
else:
new_j_base = 1
if self.axes[i].traits.overflow:
groups.append(1)
for new_j, group in enumerate(groups):
for _ in range(group):
pos = [slice(None)] * (i)
new_view[(*pos, new_j + 1, ...)] += _to_view(
reduced_view[(*pos, j, ...)]
)
if new_view.dtype.names:
for field in new_view.dtype.names:
new_view[(*pos, new_j + new_j_base, ...)][
field
] += reduced_view[(*pos, j, ...)][field]
else:
new_view[(*pos, new_j + new_j_base, ...)] += (
reduced_view[(*pos, j, ...)]
)
j += 1

reduced = new_reduced
Expand Down
55 changes: 50 additions & 5 deletions src/boost_histogram/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from builtins import sum
from typing import TYPE_CHECKING, Sequence, TypeVar

import numpy as np

if TYPE_CHECKING:
from uhi.typing.plottable import PlottableAxis

Expand Down Expand Up @@ -112,26 +114,43 @@ def __call__(self, axis: AxisLike) -> int: # noqa: ARG002

class rebin:
__slots__ = (
"axis",
"edges",
"factor",
"groups",
)

def __init__(
self,
factor: int | None = None,
factor_or_axis: int | PlottableAxis | None = None,
/,
*,
factor: int | None = None,
groups: Sequence[int] | None = None,
edges: Sequence[int | float] | None = None,
axis: PlottableAxis | None = None,
) -> None:
if not sum(i is None for i in [factor, groups]) == 1:
raise ValueError("Exactly one, a factor or groups should be provided")
self.factor = factor
if (
sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis])
!= 1
):
raise ValueError("Exactly one argument should be provided")
self.groups = groups
self.edges = edges
self.axis = axis
self.factor = factor
if isinstance(factor_or_axis, int):
self.factor = factor_or_axis
elif factor_or_axis is not None:
self.axis = factor_or_axis

def __repr__(self) -> str:
repr_str = f"{self.__class__.__name__}"
args: dict[str, int | Sequence[int] | None] = {
args: dict[str, int | Sequence[int | float] | PlottableAxis | None] = {
"factor": self.factor,
"groups": self.groups,
"edges": self.edges,
"axis": self.axis,
}
for k, v in args.items():
if v is not None:
Expand All @@ -147,4 +166,30 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
return self.groups
if self.factor is not None:
return [self.factor] * len(axis)
if self.edges is not None or self.axis is not None:
newedges = None
if self.axis is not None and hasattr(self.axis, "edges"):
newedges = self.axis.edges
elif self.edges is not None:
newedges = self.edges

if newedges is not None and hasattr(axis, "edges"):
assert newedges[0] == axis.edges[0], "Edges must start at first bin"
assert newedges[-1] == axis.edges[-1], "Edges must end at last bin"
assert all(
np.isclose(
axis.edges[np.abs(axis.edges - edge).argmin()],
edge,
)
for edge in newedges
), "Edges must be in the axis"
matched_ixes = np.where(
np.isin(
axis.edges,
newedges,
)
)[0]
return [
int(ix - matched_ixes[i]) for i, ix in enumerate(matched_ixes[1:])
]
raise ValueError("No rebinning factor or groups provided")
27 changes: 27 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,33 @@ def test_rebin_1d(metadata):
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])
assert h.axes[0].metadata is hs.axes[0].metadata

hs = h[bh.rebin(edges=[1.0, 1.2, 1.6, 2.2, 5.0])]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])

hs = h[bh.rebin(axis=bh.axis.Variable([1.0, 1.2, 1.6, 2.2, 5.0]))]
assert_array_equal(hs.view(), [1, 0, 0, 3])
assert_array_equal(hs.axes.edges[0], [1.0, 1.2, 1.6, 2.2, 5.0])


def test_rebin_1d_flow():
h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=True))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[bh.rebin(edges=[0, 3, 5.0])]
assert_array_equal(hs.view(), [2, 2])
assert_array_equal(hs.view(flow=True), [1, 2, 2, 1])
assert_array_equal(hs.axes.edges[0], [0.0, 3.0, 5.0])

h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=False, overflow=False))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[bh.rebin(edges=[0, 3, 5.0])]
assert_array_equal(hs.view(flow=True), [0, 2, 2, 0])

h = bh.Histogram(bh.axis.Regular(5, 0, 5, underflow=True, overflow=False))
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
hs = h[bh.rebin(edges=[0, 3, 5.0])]
assert_array_equal(hs.view(flow=True), [1, 2, 2, 0])


def test_shrink_rebin_1d():
h = bh.Histogram(bh.axis.Regular(20, 0, 4))
Expand Down

0 comments on commit 789d325

Please # to comment.