From ba9580a62ceea0d4310b4ffbef210aaf6e7e9a0b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:10:48 -0800 Subject: [PATCH 1/9] test and quick fix for zero batch --- .../primitives/equivariant_tensor_product.py | 5 +---- .../primitives/symmetric_tensor_product.py | 6 ++---- .../cuequivariance_torch/primitives/tensor_product.py | 5 ++++- .../tests/operations/symmetric_contraction_test.py | 10 +++++----- .../tests/operations/tp_channel_wise_test.py | 6 ++++-- docs/changelog.md | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cd0df78..113fc39 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import * +from typing import Optional, Union import torch @@ -70,9 +70,6 @@ def __init__( optimize_fallback: Optional[bool] = None, ): super().__init__() - cue.descriptors.fully_connected_tensor_product( - cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") - ) if not isinstance(layout_in, tuple): layout_in = (layout_in,) * e.num_inputs if len(layout_in) != e.num_inputs: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 9d67863..8738ac1 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -14,15 +14,13 @@ # limitations under the License. import logging import math -import warnings -from typing import * +from typing import Optional import torch import torch.fx import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet -from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) @@ -341,7 +339,7 @@ def forward( f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" ) out = self.f(x1, x0, i0) - out = out.reshape(out.shape[0], -1) + out = out.reshape(out.shape[0], out.shape[1] * self.u) return out diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b50db4f..b3c0a20 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -88,6 +88,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None): Raises: RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ + if any(x.numel() == 0 for x in args): + use_fallback = True # Empty tensors are not supported by the CUDA kernel + if ( args and args[0].device.type == "cuda" @@ -285,7 +288,7 @@ def forward(self, *args): (math.prod(shape), arg.shape[-1]) ) if math.prod(arg.shape[:-1]) > 1 - else arg.reshape((1, arg.shape[-1])) + else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1])) ) for arg in args ] diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 2ff42f3..62ba30e 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -30,7 +30,8 @@ @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("original_mace", [True, False]) -def test_symmetric_contraction(dtype, layout, original_mace): +@pytest.mark.parametrize("batch", [0, 32]) +def test_symmetric_contraction(dtype, layout, original_mace, batch): mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") irreps_out = mul * cue.Irreps("O3", "0e + 1o") @@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace): original_mace=original_mace, ) - Z = 32 - x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda() - indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda() + x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda() + indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda() out = m(x, indices) - assert out.shape == (Z, irreps_out.dim) + assert out.shape == (batch, irreps_out.dim) def from64(shape: tuple[int, ...], data: str) -> torch.Tensor: diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 155c73b..aa39e25 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -31,12 +31,14 @@ @pytest.mark.parametrize("irreps3", list_of_irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("batch", [0, 32]) def test_channel_wise( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps, layout: cue.IrrepsLayout, use_fallback: bool, + batch: int, ): m = cuet.ChannelWiseTensorProduct( irreps1, @@ -49,8 +51,8 @@ def test_channel_wise( dtype=torch.float64, ) - x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() + x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() + x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() out1 = m(x1, x2, use_fallback=use_fallback) diff --git a/docs/changelog.md b/docs/changelog.md index e1636f8..e1a8414 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,3 @@ -# Change Log +# Changelog ```{include} ../CHANGELOG.md From 0bfada92c66dd5e74ffe029fc459e64725b986ad Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:17:40 -0800 Subject: [PATCH 2/9] trigger uniform 1d in test --- cuequivariance_torch/tests/operations/tp_channel_wise_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index aa39e25..64c08c3 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -20,7 +20,7 @@ from cuequivariance import descriptors list_of_irreps = [ - cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), ] From fd097c67e0ebae227987bda1352baa8af88fbbaa Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 00:05:48 -0800 Subject: [PATCH 3/9] satisfy linter Signed-off-by: Mario Geiger --- .../experimental/mace/symmetric_contractions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index 59c8e41..ebfc5c7 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cache -from typing import * +from typing import Optional import numpy as np @@ -164,7 +164,7 @@ def U_matrix_real( assert isinstance(ir_out, cue.Irrep) if correlation == 4: - filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) + filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) # noqa E741 else: filter_ir_mid = None From 251fc4d6146e1ffa55e9287b23e294441f873d8e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:19:43 -0800 Subject: [PATCH 4/9] from typing import --- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b3c0a20..86f04a0 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,7 +15,7 @@ import logging import math import warnings -from typing import * +from typing import Optional, OrderedDict, Tuple import torch import torch.fx From 3498a32a2e124b499aa4ed7cfe842e83492db91a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:24:46 -0800 Subject: [PATCH 5/9] determine math_dtype earlier --- .../primitives/tensor_product.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 86f04a0..affa7ce 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -47,6 +47,9 @@ def __init__( super().__init__() self.descriptor = descriptor + if math_dtype is None: + math_dtype = torch.get_default_dtype() + try: self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype) except NotImplementedError as e: @@ -116,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None): def _tensor_product_fx( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, optimize_einsums: bool, ) -> torch.nn.Module: """ @@ -124,10 +127,6 @@ def _tensor_product_fx( - at least one input operand should have a batch dimension (ndim=2) - the output operand will have a batch dimension (ndim=2) """ - - if math_dtype is None: - math_dtype = torch.get_default_dtype() - descriptor = descriptor.remove_zero_paths() descriptor = descriptor.remove_empty_segments() @@ -313,7 +312,7 @@ def _sum(tensors, *, shape=None, like=None): def _tensor_product_cuda( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, ) -> torch.nn.Module: logger.debug(f"Starting search for a cuda kernel for {descriptor}") @@ -326,9 +325,6 @@ def _tensor_product_cuda( f" Got {descriptor.subscripts}." ) - if math_dtype is None: - math_dtype = torch.get_default_dtype() - if not torch.cuda.is_available(): raise NotImplementedError("CUDA is not available.") From 7f3cf05c1fe200078385a7fc8ce555035a1f9ed7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:30:41 -0800 Subject: [PATCH 6/9] warning with pip commands --- .../cuequivariance_torch/primitives/tensor_product.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index affa7ce..4ac91af 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -57,6 +57,12 @@ def __init__( self.f_cuda = None except ImportError as e: logger.warning(f"CUDA implementation not available: {e}") + logger.warning( + "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" + "Install it with one of the following commands:\n" + "pip install cuequivariance-ops-torch-cu11\n" + "pip install cuequivariance-ops-torch-cu12" + ) self.f_cuda = None self.f_fx = _tensor_product_fx( From 262433557ca53b9d3a9dcae00bca8639b97f8d76 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:41:47 -0800 Subject: [PATCH 7/9] remove unused argument --- .../cuequivariance_torch/primitives/tensor_product.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 4ac91af..736c2aa 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -443,12 +443,10 @@ def forward( self, x0: torch.Tensor, x1: torch.Tensor, - b2: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1 = self._perm(x0, x1) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - assert b2 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) x0 = _reshape(x0, shape) @@ -504,13 +502,11 @@ def forward( x0: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor, - b3: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1, x2 = self._perm(x0, x1, x2) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - assert b3 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) x0 = _reshape(x0, shape) From 91f7fce1457de45fe19547f329e0ceda86c0dd1a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:43:18 -0800 Subject: [PATCH 8/9] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7a140..0c32d2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## Latest Changes + +- Add support for empty batch dimension in `cuequivariance-torch`. + ## 0.1.0 (2024-11-18) - Beta version of cuEquivariance released. From ad2db8d5a540ba2ca54020248395bd6ba6c821b1 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:12:23 -0800 Subject: [PATCH 9/9] add Fixed subtite --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c32d2c..0d75a9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## Latest Changes +### Fixed + - Add support for empty batch dimension in `cuequivariance-torch`. ## 0.1.0 (2024-11-18)