Skip to content

Commit

Permalink
Modify how HSGP is built in PyMC when there are groups (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Apr 6, 2023
1 parent d8444ee commit 341966f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
37 changes: 24 additions & 13 deletions bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,33 +353,35 @@ def build(self, bmb_model):

# Build HSGP and store it in the term.
if self.term.by_levels is not None:
flatten_coeffs = True
coeff_dims = coeff_dims + (f"{label}_by",)
phi_list, sqrt_psd_list = [], []
self.term.hsgp = {}
# Because of the filter in the loop, it will be as if the observations were sorted
# using the values of the 'by' variable.
# This approach helps especially when there are many groups, which causes many zeros
# with other approaches (until PyMC and we have better support for sparse matrices)
indexes_to_unsort = self.term.by.argsort(kind="mergesort").argsort(kind="mergesort")
for i, level in enumerate(self.term.by_levels):
cov_func = covariance_functions[i]
# Notes:
# 'm' doesn't change by group
# We need to use list() for 'm' and 'L' because arrays are not instances of Sequence
hsgp = pm.gp.HSGP(
m=list(self.term.m), # Doesn't change by group
L=list(self.term.L[i]), # 1d array is not a Sequence
m=list(self.term.m),
L=list(self.term.L[i]),
drop_first=self.term.drop_first,
cov_func=cov_func,
)
# Notice we pass all the values, for all the groups.
# Then we only keep the ones for the corresponding group.
phi, sqrt_psd = hsgp.prior_linearized(data)
phi = phi.eval()
phi[self.term.by != i] = 0
phi, sqrt_psd = hsgp.prior_linearized(data[self.term.by == i])
sqrt_psd_list.append(sqrt_psd)
phi_list.append(phi)
phi_list.append(phi.eval())

# Store it for later usage
self.term.hsgp[level] = hsgp

phi = np.hstack(phi_list)
sqrt_psd = pt.stack(sqrt_psd_list, axis=1)
else:
flatten_coeffs = False
(cov_func,) = covariance_functions
self.term.hsgp = pm.gp.HSGP(
m=list(self.term.m),
Expand All @@ -399,9 +401,18 @@ def build(self, bmb_model):
coeffs = pm.Deterministic(f"{label}_weights", coeffs_raw * sqrt_psd, dims=coeff_dims)

# Build deterministic for the HSGP contribution
if flatten_coeffs:
coeffs = coeffs.T.flatten() # Equivalent to .flatten("F")
output = pm.Deterministic(label, phi @ coeffs, dims=contribution_dims)
# If there are groups, we do as many dot products as groups
if self.term.by_levels is not None:
contribution_list = []
for i in range(len(self.term.by_levels)):
contribution_list.append(phi_list[i] @ coeffs[:, i])
# We need to unsort the contributions so they match the original data
contribution = pt.concatenate(contribution_list)[indexes_to_unsort]
# If there are no groups, it's a single dot product
else:
contribution = phi @ coeffs

output = pm.Deterministic(label, contribution, dims=contribution_dims)
return output

def get_covariance_functions(self):
Expand Down
7 changes: 6 additions & 1 deletion bambi/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ def predict(self, idata, data=None, include_group_specific=True, hsgp_dict=None)
else:
maximum_distance = 1

# NOTE:
# The approach here differs from the one in the PyMC implementation.
# Here we have a single dot product with many zeros, while there we have many
# smaller dot products.
# It is subject to change here, but I don't want to mess up dims and coords.
if term.by_levels is not None:
by_values = x_slice[:, -1].astype(int)
x_slice = x_slice[:, :-1]
Expand All @@ -222,7 +227,7 @@ def predict(self, idata, data=None, include_group_specific=True, hsgp_dict=None)
x_slice_centered = (x_slice - term.mean) / maximum_distance
phi = term.hsgp.prior_linearized(x_slice_centered)[0].eval()

# Convert 'phi' and 'sqrt_psd' to xarray.DataArrays for easier math
# Convert 'phi' to xarray.DataArray for easier math
# Notice the extra '_' in the dim name for the weights
phi = xr.DataArray(phi, dims=(response_dim, f"{term_aliased_name}__weights_dim"))
weights = posterior[f"{term_aliased_name}_weights"]
Expand Down

0 comments on commit 341966f

Please sign in to comment.