Skip to content

Commit

Permalink
🔥 delegate nan behavior to aggregators
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasvdd committed Feb 13, 2024
1 parent 3ad7dea commit 56f71a8
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 306 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# .

Removed the `check_nans` argument of the `FigureResampler` constructor and of its `add_trace` and `add_traces` methods. This argument was used to check for (and remove) NaNs in the input data; NaN handling is now delegated to the `nan_policy` argument of specific aggregators.




# v0.9.2
`TODO`

# v0.9.1
## Major changes:
Expand Down
28 changes: 24 additions & 4 deletions plotly_resampler/aggregation/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
LTTBDownsampler,
MinMaxDownsampler,
MinMaxLTTBDownsampler,
NaNMinMaxDownsampler,
NaNMinMaxLTTBDownsampler,
# TODO -> integrate NANM4 (after the candlestick PR)
)

from ..aggregation.aggregation_interface import DataAggregator, DataPointSelector
Expand Down Expand Up @@ -171,18 +174,25 @@ class MinMaxAggregator(DataPointSelector):
"""

def __init__(self, **downsample_kwargs):
def __init__(self, nan_policy="omit", **downsample_kwargs):
"""
Parameters
----------
**downsample_kwargs
Keyword arguments passed to the :class:`MinMaxDownsampler`.
- The `parallel` argument is set to False by default.
nan_policy: str, optional
The policy to handle NaNs. Can be 'omit' or 'keep'. By default, 'omit'.
"""
# this downsampler supports all dtypes
super().__init__(**downsample_kwargs)
self.downsampler = MinMaxDownsampler()
if nan_policy not in ("omit", "keep"):
raise ValueError("nan_policy must be either 'omit' or 'keep'")
if nan_policy == "omit":
self.downsampler = MinMaxDownsampler()
else:
self.downsampler = NaNMinMaxDownsampler()

def _arg_downsample(
self,
Expand All @@ -208,21 +218,31 @@ class MinMaxLTTB(DataPointSelector):
Paper: [https://arxiv.org/pdf/2305.00332.pdf](https://arxiv.org/pdf/2305.00332.pdf)
"""

def __init__(self, minmax_ratio: int = 4, **downsample_kwargs):
def __init__(
self, minmax_ratio: int = 4, nan_policy: str = "omit", **downsample_kwargs
):
"""
Parameters
----------
minmax_ratio: int, optional
The ratio between the number of data points in the MinMax-prefetching and
the number of data points that will be outputted by LTTB. By default, 4.
nan_policy: str, optional
The policy to handle NaNs. Can be 'omit' or 'keep'. By default, 'omit'.
**downsample_kwargs
Keyword arguments passed to the `MinMaxLTTBDownsampler`.
- The `parallel` argument is set to False by default.
- The `minmax_ratio` argument is set to 4 by default, which was empirically
proven to be a good default.
"""
self.minmaxlttb = MinMaxLTTBDownsampler()
if nan_policy not in ("omit", "keep"):
raise ValueError("nan_policy must be either 'omit' or 'keep'")
if nan_policy == "omit":
self.minmaxlttb = MinMaxLTTBDownsampler()
else:
self.minmaxlttb = NaNMinMaxLTTBDownsampler()

self.minmax_ratio = minmax_ratio

super().__init__(
Expand Down
55 changes: 5 additions & 50 deletions plotly_resampler/figure_resampler/figure_resampler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,6 @@ def _parse_get_trace_props(
hf_hovertext: Iterable = None,
hf_marker_size: Iterable = None,
hf_marker_color: Iterable = None,
check_nans: bool = True,
) -> _hf_data_container:
"""Parse and capture the possibly high-frequency trace-props in a datacontainer.
Expand All @@ -572,11 +571,6 @@ def _parse_get_trace_props(
hf_hovertext : Iterable, optional
High-frequency trace "hovertext" data, overrides the current trace its
hovertext data.
check_nans: bool, optional
Whether the `hf_y` should be checked for NaNs, by default True.
As checking for NaNs is expensive, this can be disabled when the `hf_y` is
already known to contain no NaNs (or when the downsampler can handle NaNs,
e.g., EveryNthPoint).
Returns
-------
Expand Down Expand Up @@ -677,22 +671,6 @@ def _parse_get_trace_props(
if isinstance(hf_marker_color, (tuple, list, np.ndarray, pd.Series)):
hf_marker_color = np.asarray(hf_marker_color)

# Remove NaNs for efficiency (storing less meaningless data)
# NaNs introduce gaps between enclosing non-NaN data points & might distort
# the resampling algorithms
if check_nans and pd.isna(hf_y).any():
not_nan_mask = ~pd.isna(hf_y)
hf_x = hf_x[not_nan_mask]
hf_y = hf_y[not_nan_mask]
if isinstance(hf_text, np.ndarray):
hf_text = hf_text[not_nan_mask]
if isinstance(hf_hovertext, np.ndarray):
hf_hovertext = hf_hovertext[not_nan_mask]
if isinstance(hf_marker_size, np.ndarray):
hf_marker_size = hf_marker_size[not_nan_mask]
if isinstance(hf_marker_color, np.ndarray):
hf_marker_color = hf_marker_color[not_nan_mask]

# Try to parse the hf_x data if it is of object or string type
if len(hf_x) and (hf_x.dtype.type is np.str_ or hf_x.dtype == "object"):
try:
Expand Down Expand Up @@ -876,7 +854,6 @@ def add_trace(
hf_hovertext: Union[str, Iterable] = None,
hf_marker_size: Union[str, Iterable] = None,
hf_marker_color: Union[str, Iterable] = None,
check_nans: bool = True,
**trace_kwargs,
):
"""Add a trace to the figure.
Expand Down Expand Up @@ -932,13 +909,6 @@ def add_trace(
hf_marker_color: Iterable, optional
The original high frequency marker color. If set, this has priority over the
trace its ``marker.color`` argument.
check_nans: boolean, optional
If set to True, the trace's data will be checked for NaNs - which will be
removed. By default True.
As this is a costly operation, it is recommended to set this parameter to
False if you are sure that your data does not contain NaNs (or when the
downsampler can handle NaNs, e.g., EveryNthPoint). This should considerably
speed up the graph construction time.
**trace_kwargs: dict
Additional trace related keyword arguments.
e.g.: row=.., col=..., secondary_y=...
Expand Down Expand Up @@ -1019,7 +989,6 @@ def add_trace(
hf_hovertext,
hf_marker_size,
hf_marker_color,
check_nans,
)

# These traces will determine the autoscale its RANGE!
Expand Down Expand Up @@ -1078,7 +1047,6 @@ def add_traces(
downsamplers: None | List[AbstractAggregator] | AbstractAggregator = None,
gap_handlers: None | List[AbstractGapHandler] | AbstractGapHandler = None,
limit_to_views: List[bool] | bool = False,
check_nans: List[bool] | bool = True,
**traces_kwargs,
):
"""Add traces to the figure.
Expand Down Expand Up @@ -1124,14 +1092,6 @@ def add_traces(
by default False.\n
Remark that setting this parameter to True ensures that low frequency traces
are added to the ``hf_data`` property.
check_nans : None | List[bool] | bool, optional
List of check_nans booleans for the added traces. If set to True, the
trace's datapoints will be checked for NaNs. If a single boolean is passed,
all to be added traces will use this value, by default True.\n
As this is a costly operation, it is recommended to set this parameter to
False if the data is known to contain no NaNs (or when the downsampler can
handle NaNs, e.g., EveryNthPoint). This will considerably speed up the graph
construction time.
**traces_kwargs: dict
Additional trace related keyword arguments.
e.g.: rows=.., cols=..., secondary_ys=...
Expand Down Expand Up @@ -1174,16 +1134,11 @@ def add_traces(
gap_handlers = [gap_handlers] * len(data)
if isinstance(limit_to_views, bool):
limit_to_views = [limit_to_views] * len(data)
if isinstance(check_nans, bool):
check_nans = [check_nans] * len(data)

zipped = zip(
data, max_n_samples, downsamplers, gap_handlers, limit_to_views, check_nans
)
for (
i,
(trace, max_out, downsampler, gap_handler, limit_to_view, check_nan),
) in enumerate(zipped):
zipped = zip(data, max_n_samples, downsamplers, gap_handlers, limit_to_views)
for (i, (trace, max_out, downsampler, gap_handler, limit_to_view)) in enumerate(
zipped
):
if (
trace.type.lower() not in self._high_frequency_traces
or self._hf_data.get(trace.uid) is not None
Expand All @@ -1194,7 +1149,7 @@ def add_traces(
if not limit_to_view and (trace.y is None or len(trace.y) <= max_out_s):
continue

dc = self._parse_get_trace_props(trace, check_nans=check_nan)
dc = self._parse_get_trace_props(trace)
self._hf_data[trace.uid] = self._construct_hf_data_dict(
dc,
trace=trace,
Expand Down
Loading

0 comments on commit 56f71a8

Please sign in to comment.