List as inputs #28

Merged · 15 commits · Dec 3, 2024
remove unused argument
mariogeiger committed Nov 21, 2024
commit 262433557ca53b9d3a9dcae00bca8639b97f8d76
@@ -443,12 +443,10 @@ def forward(
         self,
         x0: torch.Tensor,
         x1: torch.Tensor,
-        b2: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x0, x1 = self._perm(x0, x1)
         assert x0.ndim >= 1, x0.ndim
         assert x1.ndim >= 1, x1.ndim
-        assert b2 is None

         shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1])
         x0 = _reshape(x0, shape)
@@ -504,13 +502,11 @@ def forward(
         x0: torch.Tensor,
         x1: torch.Tensor,
         x2: torch.Tensor,
-        b3: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x0, x1, x2 = self._perm(x0, x1, x2)
         assert x0.ndim >= 1, x0.ndim
         assert x1.ndim >= 1, x1.ndim
         assert x2.ndim >= 1, x2.ndim
-        assert b3 is None

         shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1])
         x0 = _reshape(x0, shape)
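
For context, a minimal standalone sketch (not part of this PR; shapes are hypothetical and _reshape's behavior is assumed) of the leading-dimension broadcasting these forward methods perform with torch.broadcast_shapes:

# Sketch only: broadcast the leading "batch" dimensions of the inputs
# while leaving each input's last (feature) dimension untouched.
import torch

x0 = torch.randn(4, 1, 8)    # batch dims (4, 1), feature dim 8
x1 = torch.randn(1, 3, 16)   # batch dims (1, 3), feature dim 16

# Broadcast only the leading dimensions, as forward() does above.
shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1])
print(shape)  # torch.Size([4, 3])

# Expand each input to the common batch shape (roughly what the module's
# _reshape helper is assumed to do) before the actual computation.
x0 = x0.expand(*shape, x0.shape[-1])
x1 = x1.expand(*shape, x1.shape[-1])
print(x0.shape, x1.shape)  # torch.Size([4, 3, 8]) torch.Size([4, 3, 16])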