diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0d75a9c..867796d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 ## Latest Changes
 
+### Changed
+
+- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
+
 ### Fixed
 
 - Add support for empty batch dimension in `cuequivariance-torch`.
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py
index e3e34da..197f5ac 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py
@@ -126,4 +126,4 @@ def forward(
         if not self.shared_weights and weight.ndim != 2:
             raise ValueError("Weights should be 2D tensor")
 
-        return self.f(weight, x, use_fallback=use_fallback)
+        return self.f([weight, x], use_fallback=use_fallback)
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py
index a9c13b8..9fb4d86 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py
@@ -96,10 +96,7 @@ def forward(
         encodings_alpha = encode_rotation_angle(alpha, self.lmax)
 
         return self.f(
-            encodings_gamma,
-            encodings_beta,
-            encodings_alpha,
-            x,
+            [encodings_gamma, encodings_beta, encodings_alpha, x],
             use_fallback=use_fallback,
         )
 
@@ -194,4 +191,4 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply the inversion layer."""
-        return self.f(x)
+        return self.f([x])
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py
index 7a8a6e2..bfd0163 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py
@@ -55,6 +55,6 @@ def spherical_harmonics(
         math_dtype=x.dtype,
         optimize_fallback=optimize_fallback,
     )
-    y = m(x)
+    y = m([x])
     y = y.reshape(vectors.shape[:-1] + (y.shape[-1],))
     return y
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py
index 9c02c19..35f1b56 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py
@@ -188,4 +188,4 @@ def forward(
             weight = self.weight
         weight = weight.flatten(1)
 
-        return self.f(weight, x, indices=indices, use_fallback=use_fallback)
+        return self.f([weight, x], indices=indices, use_fallback=use_fallback)
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py
index fcd3643..76402e7 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py
@@ -147,4 +147,4 @@ def forward(
         if not self.shared_weights and weight.ndim != 2:
             raise ValueError("Weights should be 2D tensor")
 
-        return self.f(weight, x1, x2, use_fallback=use_fallback)
+        return self.f([weight, x1, x2], use_fallback=use_fallback)
diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py
index 4f1dcf4..f33c7f6 100644
--- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py
+++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py
@@ -148,4 +148,4 @@ def forward(
         if not self.shared_weights and weight.ndim != 2:
             raise ValueError("Weights should be 2D tensor")
 
-        return self.f(weight, x1, x2, use_fallback=use_fallback)
+        return self.f([weight, x1, x2], use_fallback=use_fallback)
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
index fcc0604..ed0bbd1 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import torch
 
@@ -41,7 +41,7 @@ class EquivariantTensorProduct(torch.nn.Module):
     >>> x1 = torch.ones(17, e.inputs[1].irreps.dim)
     >>> x2 = torch.ones(17, e.inputs[2].irreps.dim)
     >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)
-    >>> tp(w, x1, x2)
+    >>> tp([w, x1, x2])
    tensor([[0., 0., 0., 0., 0., 0.],
           ...
           [0., 0., 0., 0., 0., 0.]])
@@ -50,7 +50,7 @@
 
     >>> w = torch.ones(3, e.inputs[0].irreps.dim)
    >>> indices = torch.randint(3, (17,))
-    >>> tp(w, x1, x2, indices=indices)
+    >>> tp([w, x1, x2], indices=indices)
    tensor([[0., 0., 0., 0., 0., 0.],
           ...
           [0., 0., 0., 0., 0., 0.]])
@@ -138,14 +138,14 @@ def extra_repr(self) -> str:
 
     def forward(
         self,
-        *inputs: torch.Tensor,
+        inputs: List[torch.Tensor],
         indices: Optional[torch.Tensor] = None,
         use_fallback: Optional[bool] = None,
     ) -> torch.Tensor:
         """
         If ``indices`` is not None, the first input is indexed by ``indices``.
         """
-        inputs: list[torch.Tensor] = list(inputs)
+        inputs: List[torch.Tensor] = list(inputs)
 
         assert len(inputs) == len(self.etp.inputs)
         for a, dim in zip(inputs, self.operands_dims):
@@ -164,7 +164,7 @@ def forward(
                 # TODO: at some point we will have kernel for this
                 assert len(inputs) >= 1
                 inputs[0] = inputs[0][indices]
-            output = self.tp(*inputs, use_fallback=use_fallback)
+            output = self.tp(inputs, use_fallback=use_fallback)
 
         if self.symm_tp is not None:
             if len(inputs) == 1:
@@ -174,6 +174,10 @@ def forward(
             if len(inputs) == 2:
                 [x0, x1] = inputs
                 if indices is None:
+                    torch._assert(
+                        x0.ndim == 2,
+                        f"Expected x0 to have shape (batch, dim), got {x0.shape}",
+                    )
                     if x0.shape[0] == 1:
                         indices = torch.zeros(
                             (x1.shape[0],), dtype=torch.int32, device=x1.device
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
index 8738ac1..79d92a6 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
@@ -109,7 +109,7 @@ def forward(
             use_fallback=use_fallback,
         )
         if self.f0 is not None:
-            out += self.f0()
+            out += self.f0([])
 
         return out
 
@@ -201,7 +201,7 @@ def forward(
 
         torch._assert(
             x0.ndim == 2,
-            f"Expected 2 dims (i0.max() + 1, x0_size), got {x0.ndim}",
+            f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}",
         )
         shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1])
         i0 = i0.expand(shape).reshape((math.prod(shape),))
@@ -368,6 +368,6 @@ def forward(
         self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor
     ) -> torch.Tensor:
         return sum(
-            f(x0[i0], *[x1] * (f.descriptor.num_operands - 2), use_fallback=True)
+            f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True)
             for f in self.fs
         )
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
index a133b44..1dceb9e 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
@@ -15,7 +15,7 @@
 import logging
 import math
 import warnings
-from typing import Optional, OrderedDict, Tuple
+from typing import List, Optional, OrderedDict, Tuple
 
 import torch
 import torch.fx
@@ -76,12 +76,12 @@ def __repr__(self):
         )
         return f"TensorProduct({self.descriptor} {has_cuda_kernel})"
 
-    def forward(self, *args, use_fallback: Optional[bool] = None):
+    def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None):
         r"""
         Perform the tensor product based on the specified descriptor.
 
         Args:
-            args (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one.
+            inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one.
                Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds
                to the size of each operand as defined in the tensor product descriptor.
             use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input
@@ -97,16 +97,16 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
         Raises:
             RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
         """
-        if any(x.numel() == 0 for x in args):
+        if any(x.numel() == 0 for x in inputs):
             use_fallback = True  # Empty tensors are not supported by the CUDA kernel
 
         if (
-            args
-            and args[0].device.type == "cuda"
+            inputs
+            and inputs[0].device.type == "cuda"
             and self.f_cuda is not None
             and (use_fallback is not True)
         ):
-            return self.f_cuda(*args)
+            return self.f_cuda(*inputs)
 
         if use_fallback is False:
             if self.f_cuda is not None:
@@ -119,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
                 "The fallback method is used but it has not been optimized. "
                 "Consider setting optimize_fallback=True when creating the TensorProduct module."
             )
-        return self.f_fx(*args)
+        return self.f_fx(inputs)
 
 
 def _tensor_product_fx(
@@ -278,7 +278,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu
         self.module = module
         self.descriptor = descriptor
 
-    def forward(self, *args):
+    def forward(self, args):
         for oid, arg in enumerate(args):
             torch._assert(
                 arg.shape[-1] == self.descriptor.operands[oid].size,
diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
index 64c08c3..9540c73 100644
--- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
+++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
@@ -62,7 +62,7 @@ def test_channel_wise(
     if layout == cue.mul_ir:
         d = d.add_or_transpose_modes("u,ui,j,uk+ijk")
     mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda()
-    out2 = mfx(m.weight, x1, x2, use_fallback=True)
+    out2 = mfx([m.weight, x1, x2], use_fallback=True)
 
     torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
 
diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
index 045c6ba..4e197fd 100644
--- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
+++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
@@ -59,9 +59,7 @@ def test_fully_connected(
     d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk")
     mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda()
     out2 = mfx(
-        m.weight.to(torch.float64),
-        x1.to(torch.float64),
-        x2.to(torch.float64),
+        [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)],
         use_fallback=True,
     ).to(out1.dtype)
 
diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
index 1bfd619..04d3aef 100644
--- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
@@ -84,11 +84,11 @@ def test_performance_cuda_vs_fx(
     ]
 
     for _ in range(10):
-        m(*inputs, use_fallback=False)
-        m(*inputs, use_fallback=True)
+        m(inputs, use_fallback=False)
+        m(inputs, use_fallback=True)
 
     def f(ufb: bool):
-        m(*inputs, use_fallback=ufb)
+        m(inputs, use_fallback=ufb)
         torch.cuda.synchronize()
 
     t0 = timeit.Timer(lambda: f(False)).timeit(number=10)
@@ -130,7 +130,7 @@ def test_precision_cuda_vs_fx(
         device=device,
         math_dtype=math_dtype,
     )
-    y0 = m(*inputs, use_fallback=False)
+    y0 = m(inputs, use_fallback=False)
 
     m = cuet.EquivariantTensorProduct(
         e,
@@ -140,7 +140,7 @@
         optimize_fallback=True,
     )
     inputs = map(lambda x: x.to(torch.float64), inputs)
-    y1 = m(*inputs, use_fallback=True).to(dtype)
+    y1 = m(inputs, use_fallback=True).to(dtype)
 
     torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol)
 
@@ -153,4 +153,4 @@ def test_compile():
     m_compile = torch.compile(m, fullgraph=True)
     input1 = torch.randn(100, e.inputs[0].irreps.dim)
     input2 = torch.randn(100, e.inputs[1].irreps.dim)
-    m_compile(input1, input2)
+    m_compile([input1, input2])
diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py
index 1171f8b..53e8bfc 100644
--- a/cuequivariance_torch/tests/primitives/tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py
@@ -111,12 +111,12 @@ def test_primitive_tensor_product_cuda_vs_fx(
     m = cuet.TensorProduct(
         d, device=device, math_dtype=math_dtype, optimize_fallback=False
     )
-    out1 = m(*inputs, use_fallback=False)
+    out1 = m(inputs, use_fallback=False)
 
     m = cuet.TensorProduct(
         d, device=device, math_dtype=torch.float64, optimize_fallback=False
     )
     inputs_ = [inp.clone().to(torch.float64) for inp in inputs]
-    out2 = m(*inputs_, use_fallback=True)
+    out2 = m(inputs_, use_fallback=True)
     assert out1.shape[:-1] == torch.broadcast_shapes(*batches)
     assert out1.dtype == dtype
diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst
index 1f2b20b..529eda8 100644
--- a/docs/tutorials/etp.rst
+++ b/docs/tutorials/etp.rst
@@ -94,6 +94,6 @@ We can execute an :class:`cuequivariance.EquivariantTensorProduct` with PyTorch.
 
     w = torch.randn(e.inputs[0].irreps.dim)
     x = torch.randn(e.inputs[1].irreps.dim)
-    module(w, x)
+    module([w, x])
 
 Note that you have to specify the layout. If the layout specified is different from the one in the descriptor, the module will transpose the inputs/output to match the layout.
diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst
index 2a098a1..b9516ee 100644
--- a/docs/tutorials/stp.rst
+++ b/docs/tutorials/stp.rst
@@ -112,7 +112,7 @@ Now we can execute the linear layer with random input and weight tensors.
 
     w = torch.randn(d.operands[0].size)
     x1 = torch.randn(3000, irreps1.dim)
-    x2 = linear_torch(w, x1)
+    x2 = linear_torch([w, x1])
 
     assert x2.shape == (3000, irreps2.dim)
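For downstream callers, a minimal migration sketch of the new list-based calling convention. The descriptor construction below is illustrative (it assumes the existing cue.descriptors.fully_connected_tensor_product helper and an arbitrary choice of irreps); the shapes and the list-based call itself mirror the updated doctest and tutorials in this diff.

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# Hypothetical descriptor for illustration; any descriptor accepted by
# cuet.EquivariantTensorProduct follows the same calling convention.
irreps = cue.Irreps("SO3", "2x1")
e = cue.descriptors.fully_connected_tensor_product(irreps, irreps, irreps)
tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)

w = torch.randn(17, e.inputs[0].irreps.dim)
x1 = torch.randn(17, e.inputs[1].irreps.dim)
x2 = torch.randn(17, e.inputs[2].irreps.dim)

# Old call style (removed by this change): tp(w, x1, x2)
# New call style: all operands are passed as a single list of tensors.
out = tp([w, x1, x2])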