Skip to content

Commit

Permalink
plot_hdi: add exception if x is type np.datetime64 and `smooth=Tr…
Browse files Browse the repository at this point in the history
…ue` (#2016)

* add exception if x is type np.datetime64 in plot_hdi

* updated tests for plot_hdi

* updated changelog

* fixed test

* added y and hdi_data to test

* fixed pylint error

* run black

Co-authored-by: Agustina Arroyuelo <agustinaarroyuelo@gmail.com>
  • Loading branch information
Benjamin T. Vincent and agustinaarroyuelo authored Jul 11, 2022
1 parent b2b9cbc commit 101a0f6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features

### Maintenance and fixes
* Add exception in `az.plot_hdi` for `x` of type `np.datetime64` and `smooth=True` ([2016](https://github.com/arviz-devs/arviz/pull/2016))

### Deprecation

Expand Down
3 changes: 3 additions & 0 deletions arviz/plots/hdiplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def plot_hdi(
raise TypeError(msg.format(x_shape, hdi_shape))

if smooth:
if isinstance(x[0], np.datetime64):
raise TypeError("Cannot deal with x as type datetime. Recommend setting smooth=False.")

if smooth_kwargs is None:
smooth_kwargs = {}
smooth_kwargs.setdefault("window_length", 55)
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,15 @@ def test_plot_hdi_dataset_error(models):
plot_hdi(np.arange(8), hdi_data=hdi_data)


def test_plot_hdi_datetime_error():
"""Check x as datetime raises an error."""
x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)
y_data = np.random.normal(0, 5, (1, 200, x_data.shape[0]))
hdi_data = hdi(y_data)
with pytest.raises(TypeError, match="Cannot deal with x as type datetime."):
plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)


@pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)])
def test_kde_scipy(limits):
"""
Expand Down

0 comments on commit 101a0f6

Please # to comment.