Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend model arg usage in io_pymc3 to fix plot_ppc with prior #1045

Merged
merged 10 commits into from
Feb 11, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* Fix `point_estimate` in `plot_posterior` (#1038)
* Fix interpolation `hpd_plot` (#1039)
* Fix `io_pymc3.py` to handle models with `potentials` (#1043)
* Fix several inconsistencies between schema and `from_pymc3` implementation
in groups `prior`, `prior_predictive` and `observed_data` (#1045)

### Deprecation

Expand Down
62 changes: 51 additions & 11 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""PyMC3-specific conversion code."""
import logging
import warnings
from typing import Dict, List, Any, Optional, TYPE_CHECKING
from types import ModuleType

import numpy as np
import xarray as xr
from .. import utils
from .inference_data import InferenceData, concat
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs, CoordSpec, DimSpec

if TYPE_CHECKING:
import pymc3 as pm
Expand Down Expand Up @@ -86,6 +87,14 @@ def __init__(
else:
self.nchains = self.ndraws = 0

if model is None:
warnings.warn(
"Using `from_pymc3` without the model will be deprecated in a future release. "
"Not using the model will return less accurate and less useful results. "
"Make sure you use the model argument or call from_pymc3 within a model context.",
PendingDeprecationWarning,
)

self.prior = prior
self.posterior_predictive = posterior_predictive
self.predictions = predictions
Expand Down Expand Up @@ -121,8 +130,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
def find_observations(self) -> Optional[Dict[str, Var]]:
"""If there are observations available, return them as a dictionary."""
has_observations = False
if self.trace is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like leaving out the check that tells the caller why they are not getting observations out of their trace when they expect it. Perhaps replace my error with a UserWarning?
It seems like Arviz tries to get whatever information it can out of a trace, and ignores whatever it can't figure out. This makes it relatively robust, but means that the user doesn't know that they could get more information into the InferenceData by providing a model.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely agree, I was just waiting in order to replace it with an informative warning or a deprecation warning.

assert self.model is not None, "Cannot identify observations without PymC3 model"
if self.model is not None:
if any((hasattr(obs, "observations") for obs in self.model.observed_RVs)):
has_observations = True
if has_observations:
Expand Down Expand Up @@ -221,11 +229,9 @@ def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
if self.prior is None:
return {"prior": None, "prior_predictive": None}
if self.trace is not None:
prior_vars = self.pymc3.util.get_default_varnames( # pylint: disable=no-member
self.trace.varnames, include_transformed=False
)
prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
if self.observations is not None:
prior_predictive_vars = list(self.observations.keys())
prior_vars = [key for key in self.prior.keys() if key not in prior_predictive_vars]
else:
prior_vars = list(self.prior.keys())
prior_predictive_vars = None
Expand All @@ -250,6 +256,8 @@ def priors_to_xarray(self):
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
if self.predictions:
return None
if self.dims is None:
dims = {}
else:
Expand Down Expand Up @@ -347,9 +355,41 @@ def to_inference_data(self):


def from_pymc3(
trace=None, *, prior=None, posterior_predictive=None, coords=None, dims=None, model=None
):
"""Convert pymc3 data into an InferenceData object."""
trace: Optional[MultiTrace] = None,
*,
prior: Optional[Dict[str, Any]] = None,
posterior_predictive: Optional[Dict[str, Any]] = None,
coords: Optional[CoordSpec] = None,
dims: Optional[DimSpec] = None,
model: Optional[Model] = None
) -> InferenceData:
"""Convert pymc3 data into an InferenceData object.

All three of them are optional arguments, but at least one of ``trace``,
``prior`` and ``posterior_predictive`` must be present.

Parameters
----------
trace : pymc3.MultiTrace, optional
Trace generated from MCMC sampling.
prior : dict, optional
Dictionary with the variable names as keys, and values numpy arrays
containing prior and prior predictive samples.
posterior_predictive : dict, optional
Dictionary with the variable names as keys, and values numpy arrays
containing posterior predictive samples.
coords : dict of {str: array-like}, optional
Map of coordinate names to coordinate values
dims : dict of {str: list of str}, optional
Map of variable names to the coordinate names to use to index its dimensions.
model : pymc3.Model, optional
Model used to generate ``trace``. It is not necessary to pass ``model`` if in
``with`` context.

Returns
-------
InferenceData
"""
return PyMC3Converter(
trace=trace,
prior=prior,
Expand Down
10 changes: 7 additions & 3 deletions arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def plot_ppc(
jitter,
total_pp_samples,
legend,
group,
markersize,
animation_kwargs,
num_pp_samples,
Expand Down Expand Up @@ -72,7 +73,9 @@ def plot_ppc(
plot_kwargs = {"color": "C5", "alpha": alpha, "linewidth": 0.5 * linewidth}
if dtype == "i":
plot_kwargs["drawstyle"] = "steps-pre"
ax_i.plot([], color="C5", label="Posterior predictive {}".format(pp_var_name))
ax_i.plot(
[], color="C5", label="{} predictive {}".format(group.capitalize(), pp_var_name)
)

if dtype == "f":
plot_kde(
Expand Down Expand Up @@ -126,6 +129,7 @@ def plot_ppc(
ax_i.plot(x_s, y_s, **plot_kwargs)

if mean:
label = "{} predictive mean {}".format(group.capitalize(), pp_var_name)
if dtype == "f":
rep = len(pp_densities)
len_density = len(pp_densities[0])
Expand All @@ -143,7 +147,7 @@ def plot_ppc(
linestyle="--",
linewidth=linewidth,
zorder=2,
label="Posterior predictive mean {}".format(pp_var_name),
label=label,
)
else:
vals = pp_vals.flatten()
Expand All @@ -155,7 +159,7 @@ def plot_ppc(
hist,
color="C0",
linewidth=linewidth,
label="Posterior predictive mean {}".format(pp_var_name),
label=label,
zorder=2,
linestyle="--",
drawstyle=plot_kwargs["drawstyle"],
Expand Down
2 changes: 2 additions & 0 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def plot_ppc(
jitter=jitter,
total_pp_samples=total_pp_samples,
legend=legend,
group=group,
markersize=markersize,
animation_kwargs=animation_kwargs,
num_pp_samples=num_pp_samples,
Expand All @@ -314,6 +315,7 @@ def plot_ppc(
ppcplot_kwargs.pop("animated")
ppcplot_kwargs.pop("animation_kwargs")
ppcplot_kwargs.pop("legend")
ppcplot_kwargs.pop("group")
ppcplot_kwargs.pop("xt_labelsize")
ppcplot_kwargs.pop("ax_labelsize")

Expand Down
48 changes: 40 additions & 8 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import pickle
import sys
import logging
from typing import Dict, List, Tuple, Union
import pytest
import numpy as np

from ..data import from_dict
from ..data import from_dict, InferenceData


_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,16 +158,24 @@ class Models:
return Models()


def check_multiple_attrs(test_dict, parent):
def check_multiple_attrs(
test_dict: Dict[str, List[str]], parent: InferenceData
) -> List[Union[str, Tuple[str, str]]]:
"""Perform multiple hasattr checks on InferenceData objects.

It is designed to first check whether the parent object contains a given dataset,
and then (if present) check the attributes of that dataset.

Args
----
test_dict: dict
Its structure should be `{dataset1_name: [var1, var2], dataset2_name: [var]}`
Given the output of the function, all mismatches between expectation and reality can
be retrieved: a single string indicates a group mismatch and a tuple of strings
``(group, var)`` indicates a mismatch in the variable ``var`` of ``group``.

Parameters
----------
test_dict: dict of {str : list of str}
Its structure should be `{dataset1_name: [var1, var2], dataset2_name: [var]}`.
A ``~`` at the beginning of a dataset or variable name indicates that the name NOT
being present must be asserted.
parent: InferenceData
InferenceData object on which to check the attributes.

Expand All @@ -176,13 +185,36 @@ def check_multiple_attrs(test_dict, parent):
List containing the failed checks. It will contain either the dataset_name or a
tuple (dataset_name, var) for all non present attributes.

Examples
--------
The output below indicates that the ``posterior`` group was expected but not found, and
that variables ``a`` and ``b`` were expected in the ``prior`` group but not found:

["posterior", ("prior", "a"), ("prior", "b")]

Another example could be the following:

[("posterior", "a"), "~observed_data", ("sample_stats", "~log_likelihood")]

In this case, the output indicates that variable ``a`` was not found in ``posterior``
as it was expected, however, in the other two cases, the preceding ``~`` (kept from the
input negation notation) indicates that ``observed_data`` group should not be present
but was found in the InferenceData and that ``log_likelihood`` variable was found
in ``sample_stats``, also against what was expected.

"""
failed_attrs = []
for dataset_name, attributes in test_dict.items():
if hasattr(parent, dataset_name):
if dataset_name.startswith("~"):
if hasattr(parent, dataset_name[1:]):
failed_attrs.append(dataset_name)
elif hasattr(parent, dataset_name):
dataset = getattr(parent, dataset_name)
for attribute in attributes:
if not hasattr(dataset, attribute):
if attribute.startswith("~"):
if hasattr(dataset, attribute[1:]):
failed_attrs.append((dataset_name, attribute))
elif not hasattr(dataset, attribute):
failed_attrs.append((dataset_name, attribute))
else:
failed_attrs.append(dataset_name)
Expand Down
Loading