Fix obs broadcast mismatch #4700

Merged

Conversation

brandonwillard
Contributor

This PR casts new_shape values in pymc3.aesaraf.change_rv_size before calling RandomVariable.make_node, which prevents unnecessary Casts in the resulting RandomVariable's size parameter when new_shape is a constant.

As a follow-up and/or alternative, RandomVariable.make_node could be updated in Aesara to accomplish the same thing.

Closes #4652.

@brandonwillard brandonwillard self-assigned this May 15, 2021
@brandonwillard brandonwillard added the winOS windows OS related label May 15, 2021
@brandonwillard brandonwillard linked an issue May 15, 2021 that may be closed by this pull request
@michaelosthege michaelosthege requested a review from twiecki May 15, 2021 12:33
@@ -1116,8 +1116,6 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
):
raise TypeError("Observed data cannot consist of symbolic variables.")

data = pandas_to_array(data)
Member

Won't this cause problems if a DataFrame is passed?

Member

One of the first things make_obs_var does is call pandas_to_array again, which is why this call can go away.

@@ -156,6 +156,10 @@ def change_rv_size(
size = rv_node.op._infer_shape(size, dist_params)
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)

# Make sure the new size is a tensor. This helps to not unnecessarily pick
Member

Suggested change
- # Make sure the new size is a tensor. This helps to not unnecessarily pick
+ # Make sure the new size is an int64 tensor. This helps to not unnecessarily pick

Member

Should we also link to the issue here?

Member

It's discoverable through the commits.
But on #4696 we'll make a few more robustness changes to change_rv_size. I can link the issue when rebasing the reintro_shape branch.

@michaelosthege michaelosthege merged commit 9ab831d into pymc-devs:v4 May 15, 2021
@brandonwillard brandonwillard deleted the fix-obs-broadcast-mismatch branch May 15, 2021 17:41
@brandonwillard
Contributor Author

brandonwillard commented May 15, 2021

Unfortunately, the comment change introduced by commit 1716082 is actually more misleading than the original comment, because it makes a false statement instead of one that simply lacks context.

Making new_size a Variable (or any subclass thereof) is not what avoids the Cast. For instance, passing new_size=at.as_tensor(np.array(..., dtype="int32")) to change_rv_size will still result in a Cast on the size parameter of the resulting RandomVariable node.

The Cast is avoided only when new_size isn't a Variable but is instead an ndarray, a primitive Python numeric type, a tuple or list of numeric types, etc. In other words, the exact opposite of the new comment is true.

Here's the clarified form of the original comment:

    # Make sure numeric/NumPy `new_size` values are converted to int64, so that
    # the `RandomVariable`'s new `size` doesn't unnecessarily pick up a `Cast`.
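
To illustrate the point of that comment, here is a minimal NumPy-only sketch of the numeric-side conversion. It is not the code from this PR: `normalize_size` is a hypothetical helper, and it deliberately ignores symbolic `Variable` inputs, which (as noted above) would still pick up a `Cast`. It only shows how plain numbers, sequences, and arrays of any integer dtype can be forced to int64 before they are handed on.

```python
import numpy as np


def normalize_size(new_size, size=()):
    """Hypothetical helper: prepend `new_size` to `size` and return int64.

    Only numeric/NumPy inputs are handled here; symbolic inputs are out of
    scope for this sketch.
    """
    # Same concatenation pattern as in the diff context above.
    combined = tuple(np.atleast_1d(new_size)) + tuple(size)
    # Force int64 regardless of the input dtype or the platform's default
    # integer width.
    return np.asarray(combined, dtype="int64")


# An int32 input (as could arise where NumPy's default integer is 32-bit)
# comes out as int64.
result = normalize_size(np.array([3], dtype="int32"), size=(2,))
```

With the values already in int64, a downstream consumer that expects int64 would have nothing left to convert, which is the situation the clarified comment describes.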

Labels
bug v4 winOS windows OS related
Development

Successfully merging this pull request may close these issues.

Column vector breaks observed
3 participants