Skip to content

Commit

Permalink
Prevent underscores from being removed in dim names (#664)
Browse files Browse the repository at this point in the history
* add underscores to dim names

* Fix dim names multinomial family

* Fix dimname categorical family
  • Loading branch information
tomicapretto authored Apr 7, 2023
1 parent 341966f commit d9ca5d9
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_coords(self):
new_coords = {}
for key, value in coords.items():
_, kind = key.split("__")
new_coords[self.term.alias + kind] = value
new_coords[self.term.alias + "__" + kind] = value
return new_coords

def build_distribution(self, prior, label, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion bambi/families/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def posterior_predictive(self, model, posterior, **kwargs):

def get_coords(self, response):
# For the moment, it always uses the first column as reference.
name = response.name + "_dim"
name = get_aliased_name(response) + "_dim"
labels = self.get_levels(response)
return {name: labels[1:]}

Expand Down
2 changes: 1 addition & 1 deletion bambi/families/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_data(self, response):
return np.nonzero(response.term.data)[1]

def get_coords(self, response):
name = response.name + "_dim"
name = get_aliased_name(response) + "_dim"
return {name: [level for level in response.levels if level != response.reference]}

def get_reference(self, response):
Expand Down

0 comments on commit d9ca5d9

Please # to comment.