From ba9580a62ceea0d4310b4ffbef210aaf6e7e9a0b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:10:48 -0800 Subject: [PATCH 01/12] test and quick fix for zero batch --- .../primitives/equivariant_tensor_product.py | 5 +---- .../primitives/symmetric_tensor_product.py | 6 ++---- .../cuequivariance_torch/primitives/tensor_product.py | 5 ++++- .../tests/operations/symmetric_contraction_test.py | 10 +++++----- .../tests/operations/tp_channel_wise_test.py | 6 ++++-- docs/changelog.md | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cd0df782..113fc399 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import * +from typing import Optional, Union import torch @@ -70,9 +70,6 @@ def __init__( optimize_fallback: Optional[bool] = None, ): super().__init__() - cue.descriptors.fully_connected_tensor_product( - cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") - ) if not isinstance(layout_in, tuple): layout_in = (layout_in,) * e.num_inputs if len(layout_in) != e.num_inputs: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 9d67863c..8738ac1c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -14,15 +14,13 @@ # limitations under the License. import logging import math -import warnings -from typing import * +from typing import Optional import torch import torch.fx import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet -from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) @@ -341,7 +339,7 @@ def forward( f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" ) out = self.f(x1, x0, i0) - out = out.reshape(out.shape[0], -1) + out = out.reshape(out.shape[0], out.shape[1] * self.u) return out diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b50db4fb..b3c0a203 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -88,6 +88,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None): Raises: RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. 
""" + if any(x.numel() == 0 for x in args): + use_fallback = True # Empty tensors are not supported by the CUDA kernel + if ( args and args[0].device.type == "cuda" @@ -285,7 +288,7 @@ def forward(self, *args): (math.prod(shape), arg.shape[-1]) ) if math.prod(arg.shape[:-1]) > 1 - else arg.reshape((1, arg.shape[-1])) + else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1])) ) for arg in args ] diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 2ff42f35..62ba30ee 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -30,7 +30,8 @@ @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("original_mace", [True, False]) -def test_symmetric_contraction(dtype, layout, original_mace): +@pytest.mark.parametrize("batch", [0, 32]) +def test_symmetric_contraction(dtype, layout, original_mace, batch): mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") irreps_out = mul * cue.Irreps("O3", "0e + 1o") @@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace): original_mace=original_mace, ) - Z = 32 - x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda() - indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda() + x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda() + indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda() out = m(x, indices) - assert out.shape == (Z, irreps_out.dim) + assert out.shape == (batch, irreps_out.dim) def from64(shape: tuple[int, ...], data: str) -> torch.Tensor: diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 155c73b2..aa39e25d 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -31,12 +31,14 @@ @pytest.mark.parametrize("irreps3", list_of_irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("batch", [0, 32]) def test_channel_wise( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps, layout: cue.IrrepsLayout, use_fallback: bool, + batch: int, ): m = cuet.ChannelWiseTensorProduct( irreps1, @@ -49,8 +51,8 @@ def test_channel_wise( dtype=torch.float64, ) - x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() + x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() + x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() out1 = m(x1, x2, use_fallback=use_fallback) diff --git a/docs/changelog.md b/docs/changelog.md index e1636f88..e1a84149 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,3 @@ -# Change Log +# Changelog ```{include} ../CHANGELOG.md From 0bfada92c66dd5e74ffe029fc459e64725b986ad Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:17:40 -0800 Subject: [PATCH 02/12] trigger uniform 1d in test --- cuequivariance_torch/tests/operations/tp_channel_wise_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index aa39e25d..64c08c37 100644 --- 
a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -20,7 +20,7 @@ from cuequivariance import descriptors list_of_irreps = [ - cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), ] From fd097c67e0ebae227987bda1352baa8af88fbbaa Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 00:05:48 -0800 Subject: [PATCH 03/12] satisfy linter Signed-off-by: Mario Geiger --- .../experimental/mace/symmetric_contractions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index 59c8e414..ebfc5c70 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cache -from typing import * +from typing import Optional import numpy as np @@ -164,7 +164,7 @@ def U_matrix_real( assert isinstance(ir_out, cue.Irrep) if correlation == 4: - filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) + filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) # noqa E741 else: filter_ir_mid = None From 251fc4d6146e1ffa55e9287b23e294441f873d8e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:19:43 -0800 Subject: [PATCH 04/12] from typing import --- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b3c0a203..86f04a0d 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,7 +15,7 @@ import logging import math import warnings -from typing import * +from typing import Optional, OrderedDict, Tuple import torch import torch.fx From 3498a32a2e124b499aa4ed7cfe842e83492db91a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:24:46 -0800 Subject: [PATCH 05/12] determine math_dtype earlier --- .../primitives/tensor_product.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 86f04a0d..affa7cec 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -47,6 +47,9 @@ def __init__( super().__init__() self.descriptor = descriptor + if math_dtype is None: + math_dtype = torch.get_default_dtype() + try: self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype) except NotImplementedError as e: @@ -116,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None): def _tensor_product_fx( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, optimize_einsums: bool, ) -> torch.nn.Module: """ @@ -124,10 +127,6 @@ def _tensor_product_fx( - at least one 
input operand should have a batch dimension (ndim=2) - the output operand will have a batch dimension (ndim=2) """ - - if math_dtype is None: - math_dtype = torch.get_default_dtype() - descriptor = descriptor.remove_zero_paths() descriptor = descriptor.remove_empty_segments() @@ -313,7 +312,7 @@ def _sum(tensors, *, shape=None, like=None): def _tensor_product_cuda( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, ) -> torch.nn.Module: logger.debug(f"Starting search for a cuda kernel for {descriptor}") @@ -326,9 +325,6 @@ def _tensor_product_cuda( f" Got {descriptor.subscripts}." ) - if math_dtype is None: - math_dtype = torch.get_default_dtype() - if not torch.cuda.is_available(): raise NotImplementedError("CUDA is not available.") From 7f3cf05c1fe200078385a7fc8ce555035a1f9ed7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:30:41 -0800 Subject: [PATCH 06/12] warning with pip commands --- .../cuequivariance_torch/primitives/tensor_product.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index affa7cec..4ac91af6 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -57,6 +57,12 @@ def __init__( self.f_cuda = None except ImportError as e: logger.warning(f"CUDA implementation not available: {e}") + logger.warning( + "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" + "Install it with one of the following commands:\n" + "pip install cuequivariance-ops-torch-cu11\n" + "pip install cuequivariance-ops-torch-cu12" + ) self.f_cuda = None self.f_fx = _tensor_product_fx( From 262433557ca53b9d3a9dcae00bca8639b97f8d76 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:41:47 -0800 Subject: [PATCH 07/12] remove unused argument --- .../cuequivariance_torch/primitives/tensor_product.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 4ac91af6..736c2aa4 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -443,12 +443,10 @@ def forward( self, x0: torch.Tensor, x1: torch.Tensor, - b2: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1 = self._perm(x0, x1) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - assert b2 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) x0 = _reshape(x0, shape) @@ -504,13 +502,11 @@ def forward( x0: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor, - b3: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1, x2 = self._perm(x0, x1, x2) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - assert b3 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) x0 = _reshape(x0, shape) From 91f7fce1457de45fe19547f329e0ceda86c0dd1a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:43:18 -0800 Subject: [PATCH 08/12] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7a1409..0c32d2c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 
@@ +## Latest Changes + +- Add support for empty batch dimension in `cuequivariance-torch`. + ## 0.1.0 (2024-11-18) - Beta version of cuEquivariance released. From 4401048d23e2028a7ec5ea0c1717fffae292075a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:08:49 -0800 Subject: [PATCH 09/12] list of inputs --- .../cuequivariance_torch/operations/linear.py | 2 +- .../operations/rotation.py | 7 ++----- .../operations/spherical_harmonics.py | 2 +- .../operations/symmetric_contraction.py | 2 +- .../operations/tp_channel_wise.py | 2 +- .../operations/tp_fully_connected.py | 2 +- .../primitives/equivariant_tensor_product.py | 12 ++++++------ .../primitives/symmetric_tensor_product.py | 4 ++-- .../primitives/tensor_product.py | 18 +++++++++--------- .../tests/operations/tp_channel_wise_test.py | 2 +- .../operations/tp_fully_connected_test.py | 4 +--- .../equivariant_tensor_product_test.py | 10 +++++----- .../tests/primitives/tensor_product_test.py | 4 ++-- docs/tutorials/etp.rst | 2 +- docs/tutorials/stp.rst | 2 +- 15 files changed, 35 insertions(+), 40 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index e3e34da0..197f5ac8 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -126,4 +126,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x, use_fallback=use_fallback) + return self.f([weight, x], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index a9c13b88..9fb4d864 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -96,10 +96,7 @@ def forward( encodings_alpha = encode_rotation_angle(alpha, self.lmax) return self.f( - encodings_gamma, - encodings_beta, - encodings_alpha, - x, + [encodings_gamma, encodings_beta, encodings_alpha, x], use_fallback=use_fallback, ) @@ -194,4 +191,4 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the inversion layer.""" - return self.f(x) + return self.f([x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py index 7a8a6e20..bfd01632 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py @@ -55,6 +55,6 @@ def spherical_harmonics( math_dtype=x.dtype, optimize_fallback=optimize_fallback, ) - y = m(x) + y = m([x]) y = y.reshape(vectors.shape[:-1] + (y.shape[-1],)) return y diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 9c02c193..35f1b568 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -188,4 +188,4 @@ def forward( weight = self.weight weight = weight.flatten(1) - return self.f(weight, x, indices=indices, use_fallback=use_fallback) + return self.f([weight, x], indices=indices, use_fallback=use_fallback) diff --git 
a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index fcd3643c..76402e71 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -147,4 +147,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x1, x2, use_fallback=use_fallback) + return self.f([weight, x1, x2], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 4f1dcf41..f33c7f6b 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -148,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x1, x2, use_fallback=use_fallback) + return self.f([weight, x1, x2], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 113fc399..18461b76 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch @@ -41,7 +41,7 @@ class EquivariantTensorProduct(torch.nn.Module): >>> x1 = torch.ones(17, e.inputs[1].irreps.dim) >>> x2 = torch.ones(17, e.inputs[2].irreps.dim) >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul) - >>> tp(w, x1, x2) + >>> tp([w, x1, x2]) tensor([[0., 0., 0., 0., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) @@ -50,7 +50,7 @@ class EquivariantTensorProduct(torch.nn.Module): >>> w = torch.ones(3, e.inputs[0].irreps.dim) >>> indices = torch.randint(3, (17,)) - >>> tp(w, x1, x2, indices=indices) + >>> tp([w, x1, x2], indices=indices) tensor([[0., 0., 0., 0., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) @@ -136,14 +136,14 @@ def extra_repr(self) -> str: def forward( self, - *inputs: torch.Tensor, + inputs: List[torch.Tensor], indices: Optional[torch.Tensor] = None, use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ If ``indices`` is not None, the first input is indexed by ``indices``. 
""" - inputs: list[torch.Tensor] = list(inputs) + inputs: List[torch.Tensor] = list(inputs) assert len(inputs) == len(self.etp.inputs) for a, b in zip(inputs, self.etp.inputs): @@ -162,7 +162,7 @@ def forward( # TODO: at some point we will have kernel for this assert len(inputs) >= 1 inputs[0] = inputs[0][indices] - output = self.tp(*inputs, use_fallback=use_fallback) + output = self.tp(inputs, use_fallback=use_fallback) if self.symm_tp is not None: if len(inputs) == 1: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 8738ac1c..386d04a4 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -109,7 +109,7 @@ def forward( use_fallback=use_fallback, ) if self.f0 is not None: - out += self.f0() + out += self.f0([]) return out @@ -368,6 +368,6 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f(x0[i0], *[x1] * (f.descriptor.num_operands - 2), use_fallback=True) + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 736c2aa4..644c6126 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,7 +15,7 @@ import logging import math import warnings -from typing import Optional, OrderedDict, Tuple +from typing import List, Optional, OrderedDict, Tuple import torch import torch.fx @@ -76,12 +76,12 @@ def __repr__(self): ) return f"TensorProduct({self.descriptor} {has_cuda_kernel})" - def forward(self, *args, use_fallback: Optional[bool] = None): + def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None): r""" Perform the tensor product based on the specified descriptor. Args: - args (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. + inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size of each operand as defined in the tensor product descriptor. use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input @@ -97,16 +97,16 @@ def forward(self, *args, use_fallback: Optional[bool] = None): Raises: RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ - if any(x.numel() == 0 for x in args): + if any(x.numel() == 0 for x in inputs): use_fallback = True # Empty tensors are not supported by the CUDA kernel if ( - args - and args[0].device.type == "cuda" + inputs + and inputs[0].device.type == "cuda" and self.f_cuda is not None and (use_fallback is not True) ): - return self.f_cuda(*args) + return self.f_cuda(*inputs) if use_fallback is False: if self.f_cuda is not None: @@ -119,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None): "The fallback method is used but it has not been optimized. 
" "Consider setting optimize_fallback=True when creating the TensorProduct module." ) - return self.f_fx(*args) + return self.f_fx(inputs) def _tensor_product_fx( @@ -278,7 +278,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.module = module self.descriptor = descriptor - def forward(self, *args): + def forward(self, args): for oid, arg in enumerate(args): torch._assert( arg.shape[-1] == self.descriptor.operands[oid].size, diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 64c08c37..9540c732 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -62,7 +62,7 @@ def test_channel_wise( if layout == cue.mul_ir: d = d.add_or_transpose_modes("u,ui,j,uk+ijk") mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() - out2 = mfx(m.weight, x1, x2, use_fallback=True) + out2 = mfx([m.weight, x1, x2], use_fallback=True) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index d00e6baf..64944fb6 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -59,9 +59,7 @@ def test_fully_connected( d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() out2 = mfx( - m.weight.to(torch.float64), - x1.to(torch.float64), - x2.to(torch.float64), + [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], use_fallback=True, ).to(out1.dtype) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index aa1b0ddb..0700a2c7 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -84,11 +84,11 @@ def test_performance_cuda_vs_fx( ] for _ in range(10): - m(*inputs, use_fallback=False) - m(*inputs, use_fallback=True) + m(inputs, use_fallback=False) + m(inputs, use_fallback=True) def f(ufb: bool): - m(*inputs, use_fallback=ufb) + m(inputs, use_fallback=ufb) torch.cuda.synchronize() t0 = timeit.Timer(lambda: f(False)).timeit(number=10) @@ -130,7 +130,7 @@ def test_precision_cuda_vs_fx( device=device, math_dtype=math_dtype, ) - y0 = m(*inputs, use_fallback=False) + y0 = m(inputs, use_fallback=False) m = cuet.EquivariantTensorProduct( e, @@ -140,6 +140,6 @@ def test_precision_cuda_vs_fx( optimize_fallback=True, ) inputs = map(lambda x: x.to(torch.float64), inputs) - y1 = m(*inputs, use_fallback=True).to(dtype) + y1 = m(inputs, use_fallback=True).to(dtype) torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 1171f8b7..53e8bfc3 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -111,12 +111,12 @@ def test_primitive_tensor_product_cuda_vs_fx( m = cuet.TensorProduct( d, device=device, math_dtype=math_dtype, optimize_fallback=False ) - out1 = m(*inputs, use_fallback=False) + out1 = m(inputs, use_fallback=False) m = cuet.TensorProduct( d, 
device=device, math_dtype=torch.float64, optimize_fallback=False ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] - out2 = m(*inputs_, use_fallback=True) + out2 = m(inputs_, use_fallback=True) assert out1.shape[:-1] == torch.broadcast_shapes(*batches) assert out1.dtype == dtype diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst index 1f2b20bf..529eda83 100644 --- a/docs/tutorials/etp.rst +++ b/docs/tutorials/etp.rst @@ -94,6 +94,6 @@ We can execute an :class:`cuequivariance.EquivariantTensorProduct` with PyTorch. w = torch.randn(e.inputs[0].irreps.dim) x = torch.randn(e.inputs[1].irreps.dim) - module(w, x) + module([w, x]) Note that you have to specify the layout. If the layout specified is different from the one in the descriptor, the module will transpose the inputs/output to match the layout. diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst index 2a098a19..b9516ee5 100644 --- a/docs/tutorials/stp.rst +++ b/docs/tutorials/stp.rst @@ -112,7 +112,7 @@ Now we can execute the linear layer with random input and weight tensors. w = torch.randn(d.operands[0].size) x1 = torch.randn(3000, irreps1.dim) - x2 = linear_torch(w, x1) + x2 = linear_torch([w, x1]) assert x2.shape == (3000, irreps2.dim) From ad2db8d5a540ba2ca54020248395bd6ba6c821b1 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:12:23 -0800 Subject: [PATCH 10/12] add Fixed subtite --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c32d2c6..0d75a9ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## Latest Changes +### Fixed + - Add support for empty batch dimension in `cuequivariance-torch`. ## 0.1.0 (2024-11-18) From 889051ad16c421bc912cb6733b18665c72b91341 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:13:37 -0800 Subject: [PATCH 11/12] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d75a9ce..867796d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Latest Changes +### Changed + +- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. + ### Fixed - Add support for empty batch dimension in `cuequivariance-torch`. 
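The list-input change recorded in the changelog entry above touches every call site. Below is a minimal before/after sketch of the new calling convention; the descriptor, shapes, and module construction mirror the doctest updated in patch 09/12, and the sketch is illustrative rather than part of the patch series. Passing the operands as a single list keeps the `forward` signature independent of how many operands the descriptor has.

    import torch
    import cuequivariance as cue
    import cuequivariance_torch as cuet

    # Descriptor and inputs mirror the doctest in equivariant_tensor_product.py.
    e = cue.descriptors.fully_connected_tensor_product(
        cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
    )
    tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)

    w = torch.ones(17, e.inputs[0].irreps.dim)
    x1 = torch.ones(17, e.inputs[1].irreps.dim)
    x2 = torch.ones(17, e.inputs[2].irreps.dim)

    # Before this series: out = tp(w, x1, x2)
    # After this series:  the inputs are passed as a single list.
    out = tp([w, x1, x2])
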
From c8de1858760b8112d3fde2766110b089c9e1bdf0 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 3 Dec 2024 01:43:47 -0800 Subject: [PATCH 12/12] fix --- .../primitives/equivariant_tensor_product.py | 4 ++++ .../primitives/symmetric_tensor_product.py | 2 +- .../tests/primitives/equivariant_tensor_product_test.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 2381dcad..ed0bbd12 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -174,6 +174,10 @@ def forward( if len(inputs) == 2: [x0, x1] = inputs if indices is None: + torch._assert( + x0.ndim == 2, + f"Expected x0 to have shape (batch, dim), got {x0.shape}", + ) if x0.shape[0] == 1: indices = torch.zeros( (x1.shape[0],), dtype=torch.int32, device=x1.device diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 386d04a4..79d92a65 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -201,7 +201,7 @@ def forward( torch._assert( x0.ndim == 2, - f"Expected 2 dims (i0.max() + 1, x0_size), got {x0.ndim}", + f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1]) i0 = i0.expand(shape).reshape((math.prod(shape),)) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 110b3cf3..04d3aef3 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -153,4 +153,4 @@ def test_compile(): m_compile = torch.compile(m, fullgraph=True) input1 = torch.randn(100, e.inputs[0].irreps.dim) input2 = torch.randn(100, e.inputs[1].irreps.dim) - m_compile(input1, input2) + m_compile([input1, input2])
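For the zero-batch fix from patch 01/12 (the "Fixed" changelog entry), here is a small sketch of what is now supported: `TensorProduct.forward` detects empty inputs (`numel() == 0`) and routes them to the fallback path, so a batch of size zero never reaches the CUDA kernel. The descriptor below is reused from the docstring example and is only illustrative; the exact modules exercised in the series are the `SymmetricContraction` and `ChannelWiseTensorProduct` tests parametrized with `batch in [0, 32]`.

    import torch
    import cuequivariance as cue
    import cuequivariance_torch as cuet

    # Illustrative descriptor; any descriptor works the same way.
    e = cue.descriptors.fully_connected_tensor_product(
        cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
    )
    tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)

    # Inputs with an empty batch dimension are now handled by the fallback path.
    w = torch.ones(0, e.inputs[0].irreps.dim)
    x1 = torch.ones(0, e.inputs[1].irreps.dim)
    x2 = torch.ones(0, e.inputs[2].irreps.dim)

    out = tp([w, x1, x2])
    assert out.shape[0] == 0  # the empty batch dimension is preserved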