diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 095de475f1..638fba4844 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -36,7 +36,7 @@ def __init__( dims=None, save_warmup=None, ): - self.posterior = posterior + self.posterior = posterior # CmdStanPy CmdStanMCMC object self.posterior_predictive = posterior_predictive self.predictions = predictions self.prior = prior @@ -57,91 +57,74 @@ def __init__( @requires("posterior") def posterior_to_xarray(self): """Extract posterior samples from output csv.""" - columns = self.posterior.column_names - - # filter posterior_predictive, predictions and log_likelihood - posterior_predictive = self.posterior_predictive - if posterior_predictive is None: - posterior_predictive = [] - elif isinstance(posterior_predictive, str): - posterior_predictive = [ - col for col in columns if posterior_predictive == col.split("[")[0].split(".")[0] - ] - else: - posterior_predictive = [ - col - for col in columns - if any(item == col.split("[")[0].split(".")[0] for item in posterior_predictive) - ] - - predictions = self.predictions - if predictions is None: - predictions = [] - elif isinstance(predictions, str): - predictions = [col for col in columns if predictions == col.split("[")[0].split(".")[0]] - else: - predictions = [ - col - for col in columns - if any(item == col.split("[")[0].split(".")[0] for item in predictions) - ] - - log_likelihood = self.log_likelihood - if log_likelihood is None: - log_likelihood = [] - elif isinstance(log_likelihood, str): - log_likelihood = [ - col for col in columns if log_likelihood == col.split("[")[0].split(".")[0] - ] - else: - log_likelihood = [ - col - for col in columns - if any(item == col.split("[")[0].split(".")[0] for item in log_likelihood) - ] - - invalid_cols = set( - posterior_predictive - + predictions - + log_likelihood - + [col for col in columns if col.endswith("__")] - ) - valid_cols = [col for col in columns if col not in invalid_cols] - data, data_warmup = _unpack_frame( + if not hasattr(self.posterior, "stan_vars_cols"): + return self.posterior_to_xarray_pre_v_0_9_68() + + items = list(self.posterior.stan_vars_cols.keys()) + if self.posterior_predictive is not None: + try: + items = _filter(items, self.posterior_predictive) + except ValueError: + pass + if self.predictions is not None: + try: + items = _filter(items, self.predictions) + except ValueError: + pass + if self.log_likelihood is not None: + try: + items = _filter(items, self.log_likelihood) + except ValueError: + pass + + valid_cols = [] + for item in items: + valid_cols.extend(self.posterior.stan_vars_cols[item]) + + data, data_warmup = _unpack_fit( self.posterior, - columns, - valid_cols, + items, self.save_warmup, ) + # copy dims and coords - Mitzi question: why??? + dims = deepcopy(self.dims) if self.dims is not None else {} + coords = deepcopy(self.coords) if self.coords is not None else {} + return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), - dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims - ), + dict_to_dataset(data, library=self.cmdstanpy, coords=coords, dims=dims), + dict_to_dataset(data_warmup, library=self.cmdstanpy, coords=coords, dims=dims), ) @requires("posterior") def sample_stats_to_xarray(self): + """Extract sample_stats from prosterior fit.""" + return self.stats_to_xarray(self.posterior) + + @requires("prior") + def sample_stats_prior_to_xarray(self): + """Extract sample_stats from prior fit.""" + return self.stats_to_xarray(self.prior) + + def stats_to_xarray(self, fit): """Extract sample_stats from fit.""" + if not hasattr(fit, "sampler_vars_cols"): + return self.sample_stats_to_xarray_pre_v_0_9_68(fit) + dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + items = list(self.posterior.sampler_vars_cols.keys()) - columns = self.posterior.column_names - valid_cols = [col for col in columns if col.endswith("__")] - data, data_warmup = _unpack_frame( - self.posterior, - columns, - valid_cols, + data, data_warmup = _unpack_fit( + fit, + items, self.save_warmup, ) - - for s_param in list(data.keys()): - s_param_, *_ = s_param.split(".") - name = re.sub("__$", "", s_param_) + for item in items: + name = re.sub("__$", "", item) name = "diverging" if name == "divergent" else name - data[name] = data.pop(s_param).astype(dtypes.get(s_param, float)) + data[name] = data.pop(item).astype(dtypes.get(item, float)) if data_warmup: - data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float)) + data_warmup[name] = data_warmup.pop(item).astype(dtypes.get(item, float)) return ( dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), dict_to_dataset( @@ -153,21 +136,32 @@ def sample_stats_to_xarray(self): @requires("posterior_predictive") def posterior_predictive_to_xarray(self): """Convert posterior_predictive samples to xarray.""" - posterior_predictive = self.posterior_predictive - columns = self.posterior.column_names + return self.predictive_to_xarray(self.posterior_predictive, self.posterior) - if isinstance(posterior_predictive, str): - posterior_predictive = [posterior_predictive] - posterior_predictive = set(posterior_predictive) - valid_cols = [ - col for col in columns if col.split("[")[0].split(".")[0] in posterior_predictive - ] - data, data_warmup = _unpack_frame( - self.posterior, - columns, - valid_cols, - self.save_warmup, - ) + @requires("prior") + @requires("prior_predictive") + def prior_predictive_to_xarray(self): + """Convert prior_predictive samples to xarray.""" + return self.predictive_to_xarray(self.prior_predictive, self.prior) + + def predictive_to_xarray(self, names, fit): + """Convert predictive samples to xarray.""" + predictive = _as_set(names) + + if hasattr(fit, "stan_vars_cols"): + data, data_warmup = _unpack_fit( + fit, + predictive, + self.save_warmup, + ) + else: # pre_v_0_9_68 + valid_cols = _filter_columns(fit.column_names, predictive) + data, data_warmup = _unpack_frame( + fit, + fit.column_names, + valid_cols, + self.save_warmup, + ) return ( dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), @@ -180,19 +174,23 @@ def posterior_predictive_to_xarray(self): @requires("predictions") def predictions_to_xarray(self): """Convert out of sample predictions samples to xarray.""" - predictions = self.predictions - columns = self.posterior.column_names + predictions = _as_set(self.predictions) - if isinstance(predictions, str): - predictions = [predictions] - predictions = set(predictions) - valid_cols = [col for col in columns if col.split("[")[0].split(".")[0] in set(predictions)] - data, data_warmup = _unpack_frame( - self.posterior, - columns, - valid_cols, - self.save_warmup, - ) + if hasattr(self.posterior, "stan_vars_cols"): + data, data_warmup = _unpack_fit( + self.posterior, + predictions, + self.save_warmup, + ) + else: # pre_v_0_9_68 + columns = self.posterior.column_names + valid_cols = _filter_columns(columns, predictions) + data, data_warmup = _unpack_frame( + self.posterior, + columns, + valid_cols, + self.save_warmup, + ) return ( dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), @@ -205,20 +203,23 @@ def predictions_to_xarray(self): @requires("log_likelihood") def log_likelihood_to_xarray(self): """Convert elementwise log likelihood samples to xarray.""" - log_likelihood = self.log_likelihood - columns = self.posterior.column_names - - if isinstance(log_likelihood, str): - log_likelihood = [log_likelihood] - log_likelihood = set(log_likelihood) - valid_cols = [col for col in columns if col.split("[")[0].split(".")[0] in log_likelihood] - data, data_warmup = _unpack_frame( - self.posterior, - columns, - valid_cols, - self.save_warmup, - ) + log_likelihood = _as_set(self.log_likelihood) + if hasattr(self.posterior, "stan_vars_cols"): + data, data_warmup = _unpack_fit( + self.posterior, + log_likelihood, + self.save_warmup, + ) + else: # pre_v_0_9_68 + columns = self.posterior.column_names + valid_cols = _filter_columns(columns, log_likelihood) + data, data_warmup = _unpack_frame( + self.posterior, + columns, + valid_cols, + self.save_warmup, + ) return ( dict_to_dataset( data, @@ -239,87 +240,32 @@ def log_likelihood_to_xarray(self): @requires("prior") def prior_to_xarray(self): """Convert prior samples to xarray.""" - # filter prior_predictive - columns = self.prior.column_names - - # filter posterior_predictive and log_likelihood - prior_predictive = self.prior_predictive - if prior_predictive is None: - prior_predictive = [] - elif isinstance(prior_predictive, str): - prior_predictive = [ - col for col in columns if prior_predictive == col.split("[")[0].split(".")[0] - ] - else: - prior_predictive = [ - col for col in columns if col.split("[")[0].split(".")[0] in set(prior_predictive) - ] - - invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")]) - - valid_cols = [col for col in columns if col not in invalid_cols] - - data, data_warmup = _unpack_frame( - self.prior, - columns, - valid_cols, - self.save_warmup, - ) - - return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), - dict_to_dataset( - data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims - ), - ) - - @requires("prior") - def sample_stats_prior_to_xarray(self): - """Extract sample_stats from fit.""" - dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} - - columns = self.prior.column_names - valid_cols = [col for col in columns if col.endswith("__")] - # copy dims and coords - dims = deepcopy(self.dims) if self.dims is not None else {} - coords = deepcopy(self.coords) if self.coords is not None else {} - - data, data_warmup = _unpack_frame( - self.prior, - columns, - valid_cols, - self.save_warmup, - ) - - for s_param in list(data.keys()): - s_param_, *_ = s_param.split(".") - name = re.sub("__$", "", s_param_) - name = "diverging" if name == "divergent" else name - data[name] = data.pop(s_param).astype(dtypes.get(s_param, float)) - if data_warmup: - data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float)) - return ( - dict_to_dataset(data, library=self.cmdstanpy, coords=coords, dims=dims), - dict_to_dataset(data_warmup, library=self.cmdstanpy, coords=coords, dims=dims), - ) - - @requires("prior") - @requires("prior_predictive") - def prior_predictive_to_xarray(self): - """Convert prior_predictive samples to xarray.""" - prior_predictive = self.prior_predictive - columns = self.prior.column_names - - if isinstance(prior_predictive, str): - prior_predictive = [prior_predictive] - prior_predictive = set(prior_predictive) - valid_cols = [col for col in columns if col.split("[")[0].split(".")[0] in prior_predictive] - data, data_warmup = _unpack_frame( - self.prior, - columns, - valid_cols, - self.save_warmup, - ) + if hasattr(self.posterior, "stan_vars_cols"): + items = list(self.posterior.stan_vars_cols.keys()) + if self.prior_predictive is not None: + try: + items = _filter(items, self.prior_predictive) + except ValueError: + pass + data, data_warmup = _unpack_fit( + self.posterior, + items, + self.save_warmup, + ) + else: # pre_v_0_9_68 + columns = self.prior.column_names + prior_predictive = _as_set(self.prior_predictive) + prior_predictive = _filter_columns(columns, prior_predictive) + + invalid_cols = set(prior_predictive + [col for col in columns if col.endswith("__")]) + valid_cols = [col for col in columns if col not in invalid_cols] + + data, data_warmup = _unpack_frame( + self.prior, + columns, + valid_cols, + self.save_warmup, + ) return ( dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), @@ -393,10 +339,174 @@ def to_inference_data(self): }, ) + @requires("posterior") + def posterior_to_xarray_pre_v_0_9_68(self): + """Extract posterior samples from output csv.""" + columns = self.posterior.column_names + + # filter posterior_predictive, predictions and log_likelihood + posterior_predictive = self.posterior_predictive + if posterior_predictive is None: + posterior_predictive = [] + elif isinstance(posterior_predictive, str): + posterior_predictive = [ + col for col in columns if posterior_predictive == col.split("[")[0].split(".")[0] + ] + else: + posterior_predictive = [ + col + for col in columns + if any(item == col.split("[")[0].split(".")[0] for item in posterior_predictive) + ] + + predictions = self.predictions + if predictions is None: + predictions = [] + elif isinstance(predictions, str): + predictions = [col for col in columns if predictions == col.split("[")[0].split(".")[0]] + else: + predictions = [ + col + for col in columns + if any(item == col.split("[")[0].split(".")[0] for item in predictions) + ] + + log_likelihood = self.log_likelihood + if log_likelihood is None: + log_likelihood = [] + elif isinstance(log_likelihood, str): + log_likelihood = [ + col for col in columns if log_likelihood == col.split("[")[0].split(".")[0] + ] + else: + log_likelihood = [ + col + for col in columns + if any(item == col.split("[")[0].split(".")[0] for item in log_likelihood) + ] + + invalid_cols = set( + posterior_predictive + + predictions + + log_likelihood + + [col for col in columns if col.endswith("__")] + ) + valid_cols = [col for col in columns if col not in invalid_cols] + data, data_warmup = _unpack_frame( + self.posterior, + columns, + valid_cols, + self.save_warmup, + ) + + return ( + dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), + dict_to_dataset( + data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + ), + ) + + @requires("posterior") + def sample_stats_to_xarray_pre_v_0_9_68(self, fit): + """Extract sample_stats from fit.""" + dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64} + columns = fit.column_names + valid_cols = [col for col in columns if col.endswith("__")] + data, data_warmup = _unpack_frame( + fit, + columns, + valid_cols, + self.save_warmup, + ) + for s_param in list(data.keys()): + s_param_, *_ = s_param.split(".") + name = re.sub("__$", "", s_param_) + name = "diverging" if name == "divergent" else name + data[name] = data.pop(s_param).astype(dtypes.get(s_param, float)) + if data_warmup: + data_warmup[name] = data_warmup.pop(s_param).astype(dtypes.get(s_param, float)) + return ( + dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims), + dict_to_dataset( + data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims + ), + ) + + +def _as_set(spec): + """Uniform representation for args which be name or list of names.""" + if spec is None: + return [] + if isinstance(spec, str): + return [spec] + else: + return set(spec) + + +def _filter(names, spec): + """Remove names from list of names.""" + if isinstance(spec, str): + names.remove(spec) + elif isinstance(spec, list): + for item in spec: + names.remove(item) + elif isinstance(spec, dict): + for item in spec.keys(): + names.remove(item) + return names + + +def _filter_columns(columns, spec): + """Parse variable name from column label, removing element index, if any.""" + return [col for col in columns if col.split("[")[0].split(".")[0] in spec] + + +def _unpack_fit(fit, items, save_warmup): + """Transform fit to dictionary containing ndarrays. + + Parameters + ---------- + data: cmdstanpy.CmdStanMCMC + items: list + save_warmup: bool + + Returns + ------- + dict + key, values pairs. Values are formatted to shape = (chains, draws, *shape) + """ + num_warmup = 0 + if save_warmup: + if not fit._save_warmup: # pylint: disable=protected-access + save_warmup = False + else: + num_warmup = fit.num_draws_warmup + + draws = np.swapaxes(fit.draws(inc_warmup=save_warmup), 0, 1) + sample = {} + sample_warmup = {} + + for item in items: + if item in fit.stan_vars_cols: + col_idxs = fit.stan_vars_cols[item] + elif item in fit.sampler_vars_cols: + col_idxs = fit.sampler_vars_cols[item] + else: + raise ValueError("fit data, unknown variable: {}".format(item)) + if save_warmup: + sample_warmup[item] = draws[:num_warmup, :, col_idxs] + sample[item] = draws[num_warmup:, :, col_idxs] + else: + sample[item] = draws[:, :, col_idxs] + + return sample, sample_warmup + def _unpack_frame(fit, columns, valid_cols, save_warmup): """Transform fit to dictionary containing ndarrays. + Called when fit object created by cmdstanpy version < 0.9.68 + Parameters ---------- data: cmdstanpy.CmdStanMCMC