diff --git a/CHANGELOG.md b/CHANGELOG.md
index cb7a140..0d75a9c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,9 @@
+## Latest Changes
+
+### Fixed
+
+- Add support for empty batch dimension in `cuequivariance-torch`.
+
 ## 0.1.0 (2024-11-18)
 
 - Beta version of cuEquivariance released.
diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
index 59c8e41..ebfc5c7 100644
--- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
+++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import cache
-from typing import *
+from typing import Optional
 
 import numpy as np
 
@@ -164,7 +164,7 @@ def U_matrix_real(
     assert isinstance(ir_out, cue.Irrep)
 
     if correlation == 4:
-        filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)])
+        filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)])  # noqa E741
     else:
         filter_ir_mid = None
 
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
index cd0df78..113fc39 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import *
+from typing import Optional, Union
 
 import torch
 
@@ -70,9 +70,6 @@ def __init__(
         optimize_fallback: Optional[bool] = None,
     ):
         super().__init__()
-        cue.descriptors.fully_connected_tensor_product(
-            cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
-        )
         if not isinstance(layout_in, tuple):
             layout_in = (layout_in,) * e.num_inputs
         if len(layout_in) != e.num_inputs:
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
index 9d67863..8738ac1 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py
@@ -14,15 +14,13 @@
 # limitations under the License.
 import logging
 import math
-import warnings
-from typing import *
+from typing import Optional
 
 import torch
 import torch.fx
 
 import cuequivariance.segmented_tensor_product as stp
 import cuequivariance_torch as cuet
-from cuequivariance import segmented_tensor_product as stp
 
 logger = logging.getLogger(__name__)
 
@@ -341,7 +339,7 @@ def forward(
             f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
         )
         out = self.f(x1, x0, i0)
-        out = out.reshape(out.shape[0], -1)
+        out = out.reshape(out.shape[0], out.shape[1] * self.u)
         return out
 
 
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
index b50db4f..736c2aa 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
@@ -15,7 +15,7 @@
 import logging
 import math
 import warnings
-from typing import *
+from typing import Optional, OrderedDict, Tuple
 
 import torch
 import torch.fx
@@ -47,6 +47,9 @@
         super().__init__()
         self.descriptor = descriptor
 
+        if math_dtype is None:
+            math_dtype = torch.get_default_dtype()
+
         try:
             self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype)
         except NotImplementedError as e:
@@ -54,6 +57,12 @@
             self.f_cuda = None
         except ImportError as e:
             logger.warning(f"CUDA implementation not available: {e}")
+            logger.warning(
+                "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n"
+                "Install it with one of the following commands:\n"
+                "pip install cuequivariance-ops-torch-cu11\n"
+                "pip install cuequivariance-ops-torch-cu12"
+            )
             self.f_cuda = None
 
         self.f_fx = _tensor_product_fx(
@@ -88,6 +97,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
         Raises:
             RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
         """
+        if any(x.numel() == 0 for x in args):
+            use_fallback = True  # Empty tensors are not supported by the CUDA kernel
+
         if (
             args
             and args[0].device.type == "cuda"
@@ -113,7 +125,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
 
 def _tensor_product_fx(
     descriptor: stp.SegmentedTensorProduct,
     device: Optional[torch.device],
-    math_dtype: Optional[torch.dtype],
+    math_dtype: torch.dtype,
     optimize_einsums: bool,
 ) -> torch.nn.Module:
@@ -121,10 +133,6 @@ def _tensor_product_fx(
     - at least one input operand should have a batch dimension (ndim=2)
     - the output operand will have a batch dimension (ndim=2)
     """
-
-    if math_dtype is None:
-        math_dtype = torch.get_default_dtype()
-
     descriptor = descriptor.remove_zero_paths()
     descriptor = descriptor.remove_empty_segments()
 
@@ -285,7 +293,7 @@ def forward(self, *args):
                     (math.prod(shape), arg.shape[-1])
                 )
                 if math.prod(arg.shape[:-1]) > 1
-                else arg.reshape((1, arg.shape[-1]))
+                else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1]))
             )
             for arg in args
         ]
@@ -310,7 +318,7 @@ def _sum(tensors, *, shape=None, like=None):
 def _tensor_product_cuda(
     descriptor: stp.SegmentedTensorProduct,
     device: Optional[torch.device],
-    math_dtype: Optional[torch.dtype],
+    math_dtype: torch.dtype,
 ) -> torch.nn.Module:
     logger.debug(f"Starting search for a cuda kernel for {descriptor}")
 
@@ -323,9 +331,6 @@ def _tensor_product_cuda(
            f" Got {descriptor.subscripts}."
        )
 
-    if math_dtype is None:
-        math_dtype = torch.get_default_dtype()
-
     if not torch.cuda.is_available():
         raise NotImplementedError("CUDA is not available.")
 
@@ -438,12 +443,10 @@ def forward(
         self,
         x0: torch.Tensor,
         x1: torch.Tensor,
-        b2: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x0, x1 = self._perm(x0, x1)
         assert x0.ndim >= 1, x0.ndim
         assert x1.ndim >= 1, x1.ndim
-        assert b2 is None
 
         shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1])
         x0 = _reshape(x0, shape)
@@ -499,13 +502,11 @@ def forward(
         x0: torch.Tensor,
         x1: torch.Tensor,
         x2: torch.Tensor,
-        b3: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         x0, x1, x2 = self._perm(x0, x1, x2)
         assert x0.ndim >= 1, x0.ndim
         assert x1.ndim >= 1, x1.ndim
         assert x2.ndim >= 1, x2.ndim
-        assert b3 is None
 
         shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1])
         x0 = _reshape(x0, shape)
diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py
index 2ff42f3..62ba30e 100644
--- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py
+++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py
@@ -30,7 +30,8 @@
 @pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
 @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
 @pytest.mark.parametrize("original_mace", [True, False])
-def test_symmetric_contraction(dtype, layout, original_mace):
+@pytest.mark.parametrize("batch", [0, 32])
+def test_symmetric_contraction(dtype, layout, original_mace, batch):
     mul = 64
     irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e")
     irreps_out = mul * cue.Irreps("O3", "0e + 1o")
@@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace):
         original_mace=original_mace,
     )
 
-    Z = 32
-    x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda()
-    indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda()
+    x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda()
+    indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda()
 
     out = m(x, indices)
-    assert out.shape == (Z, irreps_out.dim)
+    assert out.shape == (batch, irreps_out.dim)
 
 
 def from64(shape: tuple[int, ...], data: str) -> torch.Tensor:
diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
index 155c73b..64c08c3 100644
--- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
+++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py
@@ -20,7 +20,7 @@
 from cuequivariance import descriptors
 
 list_of_irreps = [
-    cue.Irreps("O3", "4x0e + 4x1o"),
+    cue.Irreps("O3", "32x0e + 32x1o"),
     cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"),
     cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"),
 ]
@@ -31,12 +31,14 @@
 @pytest.mark.parametrize("irreps3", list_of_irreps)
 @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
 @pytest.mark.parametrize("use_fallback", [False, True])
+@pytest.mark.parametrize("batch", [0, 32])
 def test_channel_wise(
     irreps1: cue.Irreps,
     irreps2: cue.Irreps,
     irreps3: cue.Irreps,
     layout: cue.IrrepsLayout,
     use_fallback: bool,
+    batch: int,
 ):
     m = cuet.ChannelWiseTensorProduct(
         irreps1,
@@ -49,8 +51,8 @@ def test_channel_wise(
         dtype=torch.float64,
     )
 
-    x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda()
-    x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda()
+    x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda()
+    x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda()
 
     out1 = m(x1, x2, use_fallback=use_fallback)
 
diff --git a/docs/changelog.md b/docs/changelog.md
index e1636f8..e1a8414 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,3 +1,3 @@
-# Change Log
+# Changelog
 
 ```{include} ../CHANGELOG.md