Skip to content

Commit

Permalink
Read model info to attrs with from_pystan (#1353)
Browse files Browse the repository at this point in the history
* read attrs for pystan

* fix step_size logic and inv_metric for unit_e

* update and fix changelog

* add error handling for attrs

* Fix logging

* Skip main init

* Missing chr

* pystan3 attrs

* add minimal test

* add attrs for sample_stats

* test change

* fix fit to model

* fix changelog

* rewrite pystan attrs

Co-authored-by: ahartikainen <ahartikainen@github.com>
  • Loading branch information
ahartikainen and ahartikainen authored Aug 16, 2020
1 parent 3f1d517 commit f124e64
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 31 deletions.
19 changes: 7 additions & 12 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,16 @@
* Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) (#1336)

### Maintenance and fixes
* automatic conversion of list/tuple to numpy array in distplot (#1277)
* plot_posterior: fix overlap of hdi and rope (#1263)
* Automatic conversion of list/tuple to numpy array in distplot (#1277)
* `plot_posterior` fix overlap of hdi and rope (#1263)
* `plot_dist` bins argument error fixed (#1306)
* improve handling of circular variables in `az.summary` (#1313)
* Improve handling of circular variables in `az.summary` (#1313)
* Removed change of default warning in `ELPDData` string representation (#1321)
* update `radon` example dataset to current InferenceData schema specification (#1320)
* update `from_cmdstan` functionality and add warmup groups (#1330 and #1351)
* restructure plotting code to be compatible with mpl>=3.3 (#1312)
* Replaced `_fast_kde()` with `kde()` which now also supports circular variables
via the argument `circular` (#1284).
* Update `radon` example dataset to current InferenceData schema specification (#1320)
* Update `from_cmdstan` functionality and add warmup groups (#1330 and #1351)
* Restructure plotting code to be compatible with mpl>=3.3 (#1312 and #1352)
* Replaced `_fast_kde()` with `kde()` which now also supports circular variables via the argument `circular` (#1284).

### Maintenance and fixes
* plot_posterior: fix overlap of hdi and rope (#1263)
* All the functions that used to call `_fast_kde`() now use `kde()` and have been updated to handle the new types returned (#1284).
* Increased `from_pystan` attrs information content (#1353)

### Deprecation

Expand Down
9 changes: 5 additions & 4 deletions arviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.pyplot import register_cmap, style


# Configure logging before importing arviz internals
_log = logging.getLogger("arviz")


from .data import *
from .plots import *
from .plots import backends
Expand All @@ -22,10 +27,6 @@
style.core.reload_library()


# Configure logging before importing arviz internals
_log = logging.getLogger("arviz")


if not logging.root.handlers:
_handler = logging.StreamHandler()
_formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
Expand Down
141 changes: 129 additions & 12 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import re
import warnings
from collections import OrderedDict
from copy import deepcopy

import numpy as np
import xarray as xr

from ..rcparams import rcParams
from .. import _log
from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires
from .inference_data import InferenceData

Expand Down Expand Up @@ -74,10 +76,14 @@ def posterior_to_xarray(self):
ignore = posterior_predictive + predictions + log_likelihood + ["lp__"]

data, data_warmup = get_draws(posterior, ignore=ignore, warmup=self.save_warmup)

attrs = get_attrs(posterior)
return (
dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(
data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
),
dict_to_dataset(
data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
),
)

@requires("posterior")
Expand All @@ -93,9 +99,14 @@ def sample_stats_to_xarray(self):
if stat_lp_warmup:
data_warmup["lp"] = stat_lp_warmup["lp__"]

attrs = get_attrs(posterior)
return (
dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(
data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
),
dict_to_dataset(
data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
),
)

@requires("posterior")
Expand Down Expand Up @@ -170,7 +181,10 @@ def prior_to_xarray(self):
ignore = prior_predictive + ["lp__"]

data, _ = get_draws(prior, ignore=ignore, warmup=False)
return dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims)
attrs = get_attrs(prior)
return dict_to_dataset(
data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("prior")
def sample_stats_prior_to_xarray(self):
Expand All @@ -182,7 +196,10 @@ def sample_stats_prior_to_xarray(self):
stat_lp, _ = get_draws(prior, variables="lp__", warmup=False)
data["lp"] = stat_lp["lp__"]

return dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims)
attrs = get_attrs(prior)
return dict_to_dataset(
data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("prior")
@requires("prior_predictive")
Expand Down Expand Up @@ -311,16 +328,23 @@ def posterior_to_xarray(self):
ignore = posterior_predictive + predictions + log_likelihood

data = get_draws_stan3(posterior, model=posterior_model, ignore=ignore)

return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)
attrs = get_attrs_stan3(posterior, model=posterior_model)
return dict_to_dataset(
data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("posterior")
def sample_stats_to_xarray(self):
"""Extract sample_stats from posterior."""
posterior = self.posterior
posterior_model = self.posterior_model
data = get_sample_stats_stan3(posterior, ignore="lp__")
data["lp"] = get_sample_stats_stan3(posterior, variables="lp__")["lp"]
return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)

attrs = get_attrs_stan3(posterior, model=posterior_model)
return dict_to_dataset(
data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("posterior")
@requires("log_likelihood")
Expand Down Expand Up @@ -380,14 +404,21 @@ def prior_to_xarray(self):
ignore = prior_predictive

data = get_draws_stan3(prior, model=prior_model, ignore=ignore)
return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)
attrs = get_attrs_stan3(prior, model=prior_model)
return dict_to_dataset(
data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("prior")
def sample_stats_prior_to_xarray(self):
"""Extract sample_stats_prior from prior."""
prior = self.prior
prior_model = self.prior_model
data = get_sample_stats_stan3(prior)
return dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims)
attrs = get_attrs_stan3(prior, model=prior_model)
return dict_to_dataset(
data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
)

@requires("prior")
@requires("prior_predictive")
Expand Down Expand Up @@ -623,6 +654,73 @@ def get_sample_stats(fit, warmup=False):
return data, data_warmup


def get_attrs(fit):
"""Get attributes from PyStan fit object."""
attrs = {}

try:
attrs["args"] = [deepcopy(holder.args) for holder in fit.sim["samples"]]
except Exception as exp: # pylint: disable=broad-except
_log.warning("Failed to fetch args from fit: %s", exp)
if "args" in attrs:
for arg in attrs["args"]:
if isinstance(arg["init"], bytes):
arg["init"] = arg["init"].decode("utf-8")
try:
attrs["inits"] = np.array([holder.inits for holder in fit.sim["samples"]])
except Exception as exp: # pylint: disable=broad-except
_log.warning("Failed to fetch `args` from fit: %s", exp)

attrs["step_size"] = []
attrs["metric"] = []
attrs["inv_metric"] = []
for holder in fit.sim["samples"]:
try:
step_size = float(
re.search(
r"step\s*size\s*=\s*([0-9]+.?[0-9]+)\s*",
holder.adaptation_info,
flags=re.IGNORECASE,
).group(1)
)
except AttributeError:
step_size = np.nan
attrs["step_size"].append(step_size)

inv_metric_match = re.search(
r"mass matrix:\s*(.*)\s*$", holder.adaptation_info, flags=re.DOTALL
)
if inv_metric_match:
inv_metric_str = inv_metric_match.group(1)
if "Diagonal elements of inverse mass matrix" in holder.adaptation_info:
metric = "diag_e"
inv_metric = np.array(
[float(item) for item in inv_metric_str.strip(" #\n").split(",")]
)
else:
metric = "dense_e"
inv_metric = np.array(
[
list(map(float, item.split(",")))
for item in re.sub(r"#\s", "", inv_metric_str).splitlines()
]
)
else:
metric = "unit_e"
inv_metric = None

attrs["metric"].append(metric)
attrs["inv_metric"].append(inv_metric)

if not attrs["step_size"]:
del attrs["step_size"]

attrs["adaptation_info"] = fit.get_adaptation_info()
attrs["stan_code"] = fit.get_stancode()

return attrs


def get_draws_stan3(fit, model=None, variables=None, ignore=None):
"""Extract draws from PyStan3 fit."""
if ignore is None:
Expand Down Expand Up @@ -682,6 +780,25 @@ def get_sample_stats_stan3(fit, variables=None, ignore=None):
return data


def get_attrs_stan3(fit, model=None):
"""Get attributes from PyStan3 fit and model object."""
attrs = {}
for key in ["num_chains", "num_samples", "num_thin", "num_warmup", "save_warmup"]:
try:
attrs[key] = getattr(fit, key)
except AttributeError as exp:
_log.warning("Failed to access attribute %s in fit object %s", key, exp)

if model is not None:
for key in ["model_name", "program_code", "random_seed"]:
try:
attrs[key] = getattr(model, key)
except AttributeError as exp:
_log.warning("Failed to access attribute %s in model object %s", key, exp)

return attrs


def infer_dtypes(fit, model=None):
"""Infer dtypes from Stan model code.
Expand Down
6 changes: 6 additions & 0 deletions arviz/tests/external_tests/test_data_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def test_inference_data(self, data, eight_schools_params):
}
fails = check_multiple_attrs(test_dict, inference_data2)
assert not fails
assert any(
item in inference_data2.posterior.attrs for item in ["stan_code", "program_code"]
)
assert any(
item in inference_data2.sample_stats.attrs for item in ["stan_code", "program_code"]
)
# inference_data 3
test_dict = {
"posterior_predictive": ["y_hat", "log_lik"],
Expand Down
4 changes: 2 additions & 2 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ def pystan_version():
version = int(pystan.__version__[0])
except ImportError:
try:
import stan as pystan # pylint: disable=import-error
import stan # pylint: disable=import-error

version = int(pystan.__version__[0])
version = int(stan.__version__[0])
except ImportError:
version = None
return version
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ use_parentheses = true
line_length = 100
skip = [
"arviz/plots/backends/bokeh/__init__.py",
"arviz/plots/backends/matplotlib/__init__.py"
"arviz/plots/backends/matplotlib/__init__.py",
"arviz/__init__.py"
]

0 comments on commit f124e64

Please # to comment.