From ba9580a62ceea0d4310b4ffbef210aaf6e7e9a0b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:10:48 -0800 Subject: [PATCH 01/44] test and quick fix for zero batch --- .../primitives/equivariant_tensor_product.py | 5 +---- .../primitives/symmetric_tensor_product.py | 6 ++---- .../cuequivariance_torch/primitives/tensor_product.py | 5 ++++- .../tests/operations/symmetric_contraction_test.py | 10 +++++----- .../tests/operations/tp_channel_wise_test.py | 6 ++++-- docs/changelog.md | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cd0df78..113fc39 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import * +from typing import Optional, Union import torch @@ -70,9 +70,6 @@ def __init__( optimize_fallback: Optional[bool] = None, ): super().__init__() - cue.descriptors.fully_connected_tensor_product( - cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") - ) if not isinstance(layout_in, tuple): layout_in = (layout_in,) * e.num_inputs if len(layout_in) != e.num_inputs: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 9d67863..8738ac1 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -14,15 +14,13 @@ # limitations under the License. import logging import math -import warnings -from typing import * +from typing import Optional import torch import torch.fx import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet -from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) @@ -341,7 +339,7 @@ def forward( f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" ) out = self.f(x1, x0, i0) - out = out.reshape(out.shape[0], -1) + out = out.reshape(out.shape[0], out.shape[1] * self.u) return out diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b50db4f..b3c0a20 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -88,6 +88,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None): Raises: RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ + if any(x.numel() == 0 for x in args): + use_fallback = True # Empty tensors are not supported by the CUDA kernel + if ( args and args[0].device.type == "cuda" @@ -285,7 +288,7 @@ def forward(self, *args): (math.prod(shape), arg.shape[-1]) ) if math.prod(arg.shape[:-1]) > 1 - else arg.reshape((1, arg.shape[-1])) + else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1])) ) for arg in args ] diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 2ff42f3..62ba30e 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -30,7 +30,8 @@ @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("original_mace", [True, False]) -def test_symmetric_contraction(dtype, layout, original_mace): +@pytest.mark.parametrize("batch", [0, 32]) +def test_symmetric_contraction(dtype, layout, original_mace, batch): mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") irreps_out = mul * cue.Irreps("O3", "0e + 1o") @@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace): original_mace=original_mace, ) - Z = 32 - x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda() - indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda() + x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda() + indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda() out = m(x, indices) - assert out.shape == (Z, irreps_out.dim) + assert out.shape == (batch, irreps_out.dim) def from64(shape: tuple[int, ...], data: str) -> torch.Tensor: diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 155c73b..aa39e25 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -31,12 +31,14 @@ @pytest.mark.parametrize("irreps3", list_of_irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) +@pytest.mark.parametrize("batch", [0, 32]) def test_channel_wise( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps, layout: cue.IrrepsLayout, use_fallback: bool, + batch: int, ): m = cuet.ChannelWiseTensorProduct( irreps1, @@ -49,8 +51,8 @@ def test_channel_wise( dtype=torch.float64, ) - x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() + x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() + x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() out1 = m(x1, x2, use_fallback=use_fallback) diff --git a/docs/changelog.md b/docs/changelog.md index e1636f8..e1a8414 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,3 @@ -# Change Log +# Changelog ```{include} ../CHANGELOG.md From 0bfada92c66dd5e74ffe029fc459e64725b986ad Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 20 Nov 2024 07:17:40 -0800 Subject: [PATCH 02/44] trigger uniform 1d in test --- cuequivariance_torch/tests/operations/tp_channel_wise_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index aa39e25..64c08c3 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -20,7 +20,7 @@ from cuequivariance import descriptors list_of_irreps = [ - cue.Irreps("O3", "4x0e + 4x1o"), + cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"), ] From fd097c67e0ebae227987bda1352baa8af88fbbaa Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 00:05:48 -0800 Subject: [PATCH 03/44] satisfy linter Signed-off-by: Mario Geiger --- .../experimental/mace/symmetric_contractions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index 59c8e41..ebfc5c7 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import cache -from typing import * +from typing import Optional import numpy as np @@ -164,7 +164,7 @@ def U_matrix_real( assert isinstance(ir_out, cue.Irrep) if correlation == 4: - filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) + filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) # noqa E741 else: filter_ir_mid = None From 251fc4d6146e1ffa55e9287b23e294441f873d8e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:19:43 -0800 Subject: [PATCH 04/44] from typing import --- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index b3c0a20..86f04a0 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,7 +15,7 @@ import logging import math import warnings -from typing import * +from typing import Optional, OrderedDict, Tuple import torch import torch.fx From 3498a32a2e124b499aa4ed7cfe842e83492db91a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:24:46 -0800 Subject: [PATCH 05/44] determine math_dtype earlier --- .../primitives/tensor_product.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 86f04a0..affa7ce 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -47,6 +47,9 @@ def __init__( super().__init__() self.descriptor = descriptor + if math_dtype is None: + math_dtype = torch.get_default_dtype() + try: self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype) except NotImplementedError as e: @@ -116,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None): def _tensor_product_fx( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, optimize_einsums: bool, ) -> torch.nn.Module: """ @@ -124,10 +127,6 @@ def _tensor_product_fx( - at least one input operand should have a batch dimension (ndim=2) - the output operand will have a batch dimension (ndim=2) """ - - if math_dtype is None: - math_dtype = torch.get_default_dtype() - descriptor = descriptor.remove_zero_paths() descriptor = descriptor.remove_empty_segments() @@ -313,7 +312,7 @@ def _sum(tensors, *, shape=None, like=None): def _tensor_product_cuda( descriptor: stp.SegmentedTensorProduct, device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, ) -> torch.nn.Module: logger.debug(f"Starting search for a cuda kernel for {descriptor}") @@ -326,9 +325,6 @@ def _tensor_product_cuda( f" Got {descriptor.subscripts}." ) - if math_dtype is None: - math_dtype = torch.get_default_dtype() - if not torch.cuda.is_available(): raise NotImplementedError("CUDA is not available.") From 7f3cf05c1fe200078385a7fc8ce555035a1f9ed7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:30:41 -0800 Subject: [PATCH 06/44] warning with pip commands --- .../cuequivariance_torch/primitives/tensor_product.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index affa7ce..4ac91af 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -57,6 +57,12 @@ def __init__( self.f_cuda = None except ImportError as e: logger.warning(f"CUDA implementation not available: {e}") + logger.warning( + "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" + "Install it with one of the following commands:\n" + "pip install cuequivariance-ops-torch-cu11\n" + "pip install cuequivariance-ops-torch-cu12" + ) self.f_cuda = None self.f_fx = _tensor_product_fx( From 262433557ca53b9d3a9dcae00bca8639b97f8d76 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:41:47 -0800 Subject: [PATCH 07/44] remove unused argument --- .../cuequivariance_torch/primitives/tensor_product.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 4ac91af..736c2aa 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -443,12 +443,10 @@ def forward( self, x0: torch.Tensor, x1: torch.Tensor, - b2: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1 = self._perm(x0, x1) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - assert b2 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) x0 = _reshape(x0, shape) @@ -504,13 +502,11 @@ def forward( x0: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor, - b3: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1, x2 = self._perm(x0, x1, x2) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - assert b3 is None shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) x0 = _reshape(x0, shape) From 91f7fce1457de45fe19547f329e0ceda86c0dd1a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 06:43:18 -0800 Subject: [PATCH 08/44] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb7a140..0c32d2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## Latest Changes + +- Add support for empty batch dimension in `cuequivariance-torch`. + ## 0.1.0 (2024-11-18) - Beta version of cuEquivariance released. From 4401048d23e2028a7ec5ea0c1717fffae292075a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:08:49 -0800 Subject: [PATCH 09/44] list of inputs --- .../cuequivariance_torch/operations/linear.py | 2 +- .../operations/rotation.py | 7 ++----- .../operations/spherical_harmonics.py | 2 +- .../operations/symmetric_contraction.py | 2 +- .../operations/tp_channel_wise.py | 2 +- .../operations/tp_fully_connected.py | 2 +- .../primitives/equivariant_tensor_product.py | 12 ++++++------ .../primitives/symmetric_tensor_product.py | 4 ++-- .../primitives/tensor_product.py | 18 +++++++++--------- .../tests/operations/tp_channel_wise_test.py | 2 +- .../operations/tp_fully_connected_test.py | 4 +--- .../equivariant_tensor_product_test.py | 10 +++++----- .../tests/primitives/tensor_product_test.py | 4 ++-- docs/tutorials/etp.rst | 2 +- docs/tutorials/stp.rst | 2 +- 15 files changed, 35 insertions(+), 40 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index e3e34da..197f5ac 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -126,4 +126,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x, use_fallback=use_fallback) + return self.f([weight, x], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index a9c13b8..9fb4d86 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -96,10 +96,7 @@ def forward( encodings_alpha = encode_rotation_angle(alpha, self.lmax) return self.f( - encodings_gamma, - encodings_beta, - encodings_alpha, - x, + [encodings_gamma, encodings_beta, encodings_alpha, x], use_fallback=use_fallback, ) @@ -194,4 +191,4 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the inversion layer.""" - return self.f(x) + return self.f([x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py index 7a8a6e2..bfd0163 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py @@ -55,6 +55,6 @@ def spherical_harmonics( math_dtype=x.dtype, optimize_fallback=optimize_fallback, ) - y = m(x) + y = m([x]) y = y.reshape(vectors.shape[:-1] + (y.shape[-1],)) return y diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 9c02c19..35f1b56 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -188,4 +188,4 @@ def forward( weight = self.weight weight = weight.flatten(1) - return self.f(weight, x, indices=indices, use_fallback=use_fallback) + return self.f([weight, x], indices=indices, use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index fcd3643..76402e7 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -147,4 +147,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x1, x2, use_fallback=use_fallback) + return self.f([weight, x1, x2], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 4f1dcf4..f33c7f6 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -148,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f(weight, x1, x2, use_fallback=use_fallback) + return self.f([weight, x1, x2], use_fallback=use_fallback) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 113fc39..18461b7 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import List, Optional, Union import torch @@ -41,7 +41,7 @@ class EquivariantTensorProduct(torch.nn.Module): >>> x1 = torch.ones(17, e.inputs[1].irreps.dim) >>> x2 = torch.ones(17, e.inputs[2].irreps.dim) >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul) - >>> tp(w, x1, x2) + >>> tp([w, x1, x2]) tensor([[0., 0., 0., 0., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) @@ -50,7 +50,7 @@ class EquivariantTensorProduct(torch.nn.Module): >>> w = torch.ones(3, e.inputs[0].irreps.dim) >>> indices = torch.randint(3, (17,)) - >>> tp(w, x1, x2, indices=indices) + >>> tp([w, x1, x2], indices=indices) tensor([[0., 0., 0., 0., 0., 0.], ... [0., 0., 0., 0., 0., 0.]]) @@ -136,14 +136,14 @@ def extra_repr(self) -> str: def forward( self, - *inputs: torch.Tensor, + inputs: List[torch.Tensor], indices: Optional[torch.Tensor] = None, use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - inputs: list[torch.Tensor] = list(inputs) + inputs: List[torch.Tensor] = list(inputs) assert len(inputs) == len(self.etp.inputs) for a, b in zip(inputs, self.etp.inputs): @@ -162,7 +162,7 @@ def forward( # TODO: at some point we will have kernel for this assert len(inputs) >= 1 inputs[0] = inputs[0][indices] - output = self.tp(*inputs, use_fallback=use_fallback) + output = self.tp(inputs, use_fallback=use_fallback) if self.symm_tp is not None: if len(inputs) == 1: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 8738ac1..386d04a 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -109,7 +109,7 @@ def forward( use_fallback=use_fallback, ) if self.f0 is not None: - out += self.f0() + out += self.f0([]) return out @@ -368,6 +368,6 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f(x0[i0], *[x1] * (f.descriptor.num_operands - 2), use_fallback=True) + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 736c2aa..644c612 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -15,7 +15,7 @@ import logging import math import warnings -from typing import Optional, OrderedDict, Tuple +from typing import List, Optional, OrderedDict, Tuple import torch import torch.fx @@ -76,12 +76,12 @@ def __repr__(self): ) return f"TensorProduct({self.descriptor} {has_cuda_kernel})" - def forward(self, *args, use_fallback: Optional[bool] = None): + def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None): r""" Perform the tensor product based on the specified descriptor. Args: - args (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. + inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size of each operand as defined in the tensor product descriptor. use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input @@ -97,16 +97,16 @@ def forward(self, *args, use_fallback: Optional[bool] = None): Raises: RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ - if any(x.numel() == 0 for x in args): + if any(x.numel() == 0 for x in inputs): use_fallback = True # Empty tensors are not supported by the CUDA kernel if ( - args - and args[0].device.type == "cuda" + inputs + and inputs[0].device.type == "cuda" and self.f_cuda is not None and (use_fallback is not True) ): - return self.f_cuda(*args) + return self.f_cuda(*inputs) if use_fallback is False: if self.f_cuda is not None: @@ -119,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None): "The fallback method is used but it has not been optimized. " "Consider setting optimize_fallback=True when creating the TensorProduct module." ) - return self.f_fx(*args) + return self.f_fx(inputs) def _tensor_product_fx( @@ -278,7 +278,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.module = module self.descriptor = descriptor - def forward(self, *args): + def forward(self, args): for oid, arg in enumerate(args): torch._assert( arg.shape[-1] == self.descriptor.operands[oid].size, diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 64c08c3..9540c73 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -62,7 +62,7 @@ def test_channel_wise( if layout == cue.mul_ir: d = d.add_or_transpose_modes("u,ui,j,uk+ijk") mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() - out2 = mfx(m.weight, x1, x2, use_fallback=True) + out2 = mfx([m.weight, x1, x2], use_fallback=True) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index d00e6ba..64944fb 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -59,9 +59,7 @@ def test_fully_connected( d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() out2 = mfx( - m.weight.to(torch.float64), - x1.to(torch.float64), - x2.to(torch.float64), + [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], use_fallback=True, ).to(out1.dtype) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index aa1b0dd..0700a2c 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -84,11 +84,11 @@ def test_performance_cuda_vs_fx( ] for _ in range(10): - m(*inputs, use_fallback=False) - m(*inputs, use_fallback=True) + m(inputs, use_fallback=False) + m(inputs, use_fallback=True) def f(ufb: bool): - m(*inputs, use_fallback=ufb) + m(inputs, use_fallback=ufb) torch.cuda.synchronize() t0 = timeit.Timer(lambda: f(False)).timeit(number=10) @@ -130,7 +130,7 @@ def test_precision_cuda_vs_fx( device=device, math_dtype=math_dtype, ) - y0 = m(*inputs, use_fallback=False) + y0 = m(inputs, use_fallback=False) m = cuet.EquivariantTensorProduct( e, @@ -140,6 +140,6 @@ def test_precision_cuda_vs_fx( optimize_fallback=True, ) inputs = map(lambda x: x.to(torch.float64), inputs) - y1 = m(*inputs, use_fallback=True).to(dtype) + y1 = m(inputs, use_fallback=True).to(dtype) torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 1171f8b..53e8bfc 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -111,12 +111,12 @@ def test_primitive_tensor_product_cuda_vs_fx( m = cuet.TensorProduct( d, device=device, math_dtype=math_dtype, optimize_fallback=False ) - out1 = m(*inputs, use_fallback=False) + out1 = m(inputs, use_fallback=False) m = cuet.TensorProduct( d, device=device, math_dtype=torch.float64, optimize_fallback=False ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] - out2 = m(*inputs_, use_fallback=True) + out2 = m(inputs_, use_fallback=True) assert out1.shape[:-1] == torch.broadcast_shapes(*batches) assert out1.dtype == dtype diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst index 1f2b20b..529eda8 100644 --- a/docs/tutorials/etp.rst +++ b/docs/tutorials/etp.rst @@ -94,6 +94,6 @@ We can execute an :class:`cuequivariance.EquivariantTensorProduct` with PyTorch. w = torch.randn(e.inputs[0].irreps.dim) x = torch.randn(e.inputs[1].irreps.dim) - module(w, x) + module([w, x]) Note that you have to specify the layout. If the layout specified is different from the one in the descriptor, the module will transpose the inputs/output to match the layout. diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst index 2a098a1..b9516ee 100644 --- a/docs/tutorials/stp.rst +++ b/docs/tutorials/stp.rst @@ -112,7 +112,7 @@ Now we can execute the linear layer with random input and weight tensors. w = torch.randn(d.operands[0].size) x1 = torch.randn(3000, irreps1.dim) - x2 = linear_torch(w, x1) + x2 = linear_torch([w, x1]) assert x2.shape == (3000, irreps2.dim) From ad2db8d5a540ba2ca54020248395bd6ba6c821b1 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:12:23 -0800 Subject: [PATCH 10/44] add Fixed subtite --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c32d2c..0d75a9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## Latest Changes +### Fixed + - Add support for empty batch dimension in `cuequivariance-torch`. ## 0.1.0 (2024-11-18) From 889051ad16c421bc912cb6733b18665c72b91341 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 21 Nov 2024 07:13:37 -0800 Subject: [PATCH 11/44] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d75a9c..867796d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Latest Changes +### Changed + +- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. + ### Fixed - Add support for empty batch dimension in `cuequivariance-torch`. From bc6b405c7e23fd8ab18b2a935e4c4bfffdec0c88 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 3 Dec 2024 01:24:15 -0800 Subject: [PATCH 12/44] add test for torch.jit.script --- .../primitives/equivariant_tensor_product_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 110b3cf..d08af4b 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -154,3 +154,14 @@ def test_compile(): input1 = torch.randn(100, e.inputs[0].irreps.dim) input2 = torch.randn(100, e.inputs[1].irreps.dim) m_compile(input1, input2) + + +def test_script(): + e = cue.descriptors.symmetric_contraction( + cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] + ) + m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False) + m_script = torch.jit.script(m) + input1 = torch.randn(100, e.inputs[0].irreps.dim) + input2 = torch.randn(100, e.inputs[1].irreps.dim) + m_script(input1, input2) From c8de1858760b8112d3fde2766110b089c9e1bdf0 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 3 Dec 2024 01:43:47 -0800 Subject: [PATCH 13/44] fix --- .../primitives/equivariant_tensor_product.py | 4 ++++ .../primitives/symmetric_tensor_product.py | 2 +- .../tests/primitives/equivariant_tensor_product_test.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 2381dca..ed0bbd1 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -174,6 +174,10 @@ def forward( if len(inputs) == 2: [x0, x1] = inputs if indices is None: + torch._assert( + x0.ndim == 2, + f"Expected x0 to have shape (batch, dim), got {x0.shape}", + ) if x0.shape[0] == 1: indices = torch.zeros( (x1.shape[0],), dtype=torch.int32, device=x1.device diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 386d04a..79d92a6 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -201,7 +201,7 @@ def forward( torch._assert( x0.ndim == 2, - f"Expected 2 dims (i0.max() + 1, x0_size), got {x0.ndim}", + f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1]) i0 = i0.expand(shape).reshape((math.prod(shape),)) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 110b3cf..04d3aef 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -153,4 +153,4 @@ def test_compile(): m_compile = torch.compile(m, fullgraph=True) input1 = torch.randn(100, e.inputs[0].irreps.dim) input2 = torch.randn(100, e.inputs[1].irreps.dim) - m_compile(input1, input2) + m_compile([input1, input2]) From 16e4450b2acbeec082b8d6d4b9080c5388467298 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 3 Dec 2024 01:49:47 -0800 Subject: [PATCH 14/44] remove keyword-only and import in the forward --- .../primitives/transpose.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 8f6546f..4c40036 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import * +from typing import Optional import torch import torch.fx @@ -59,7 +59,7 @@ def __repr__(self): return f"TransposeIrrepsLayout({self.source} -> {self.target})" def forward( - self, x: torch.Tensor, *, use_fallback: Optional[bool] = None + self, x: torch.Tensor, use_fallback: Optional[bool] = None ) -> torch.Tensor: r""" Perform the transposition. @@ -92,7 +92,7 @@ def __init__( if info is not None: try: - import cuequivariance_ops_torch + import cuequivariance_ops_torch # noqa: F401 except ImportError: self.f_cuda = None else: @@ -104,10 +104,10 @@ def __init__( self.f = torch.nn.Identity() def __repr__(self): - return f"TransposeSegments()" + return "TransposeSegments()" def forward( - self, x: torch.Tensor, *, use_fallback: Optional[bool] = None + self, x: torch.Tensor, use_fallback: Optional[bool] = None ) -> torch.Tensor: """ Perform the transposition of the input tensor using either a CUDA kernel or a PyTorch fallback. @@ -184,12 +184,16 @@ def _transpose_info( return torch.IntTensor(info).to(device=device) +try: + from cuequivariance_ops_torch import segmented_transpose +except ImportError: + pass + + class _transpose(torch.nn.Module): def __init__(self, info: torch.IntTensor): super().__init__() self.register_buffer("_info", info, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - from cuequivariance_ops_torch import segmented_transpose - return segmented_transpose(x, self._info, True) From b2c4fbb8653fdceb5f60aedd2c73949483065741 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 4 Dec 2024 00:54:32 -0800 Subject: [PATCH 15/44] low lvl script tests --- .../primitives/symmetric_tensor_product.py | 8 +-- .../equivariant_tensor_product_test.py | 11 ---- .../tests/primitives/script_test.py | 66 +++++++++++++++++++ 3 files changed, 70 insertions(+), 15 deletions(-) create mode 100644 cuequivariance_torch/tests/primitives/script_test.py diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 79d92a6..1417b67 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -138,6 +138,9 @@ def __init__( ): super().__init__() + if math_dtype is None: + math_dtype = torch.get_default_dtype() + _check_descriptors(descriptors) self.descriptors = descriptors @@ -258,13 +261,10 @@ def __init__( self, stps: list[stp.SegmentedTensorProduct], device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, ): super().__init__() - if math_dtype is None: - math_dtype = torch.get_default_dtype() - max_degree = max(d.num_operands - 2 for d in stps) if max_degree > 6: raise NotImplementedError("Correlation > 6 is not implemented.") diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index e040687..04d3aef 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -154,14 +154,3 @@ def test_compile(): input1 = torch.randn(100, e.inputs[0].irreps.dim) input2 = torch.randn(100, e.inputs[1].irreps.dim) m_compile([input1, input2]) - - -def test_script(): - e = cue.descriptors.symmetric_contraction( - cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] - ) - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False) - m_script = torch.jit.script(m) - input1 = torch.randn(100, e.inputs[0].irreps.dim) - input2 = torch.randn(100, e.inputs[1].irreps.dim) - m_script([input1, input2]) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/script_test.py new file mode 100644 index 0000000..8829041 --- /dev/null +++ b/cuequivariance_torch/tests/primitives/script_test.py @@ -0,0 +1,66 @@ +import torch + +import cuequivariance as cue +from cuequivariance_torch.primitives.symmetric_tensor_product import ( + CUDAKernel as SymmetricTensorProduct, +) +from cuequivariance_torch.primitives.tensor_product import ( + FusedTensorProductOp3, + TensorProductUniform3x1d, +) + + +def test_script_symmetric_contraction(): + ds = cue.descriptors.symmetric_contraction( + 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] + ).ds + + batch = 12 + x0 = torch.randn(3, ds[0].operands[0].size, device="cuda:0", dtype=torch.float32) + i0 = torch.zeros(batch, device="cuda:0", dtype=torch.int32) + x1 = torch.randn( + batch, ds[0].operands[1].size, device="cuda:0", dtype=torch.float32 + ) + + module = SymmetricTensorProduct(ds, torch.device("cuda:0"), torch.float32) + module = torch.jit.script(module) + + assert module(x0, i0, x1).shape == (batch, ds[0].operands[-1].size) + + +def test_script_fused_tp(): + d = ( + cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + + module = FusedTensorProductOp3(d, (0, 1), torch.device("cuda:0"), torch.float32) + module = torch.jit.script(module) + + assert module(x0, x1).shape == (batch, d.operands[2].size) + + +def test_script_uniform_tp(): + d = ( + cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + + module = TensorProductUniform3x1d(d, torch.device("cuda:0"), torch.float32) + module = torch.jit.script(module) + + assert module(x0, x1).shape == (batch, d.operands[2].size) From 4669a86f8da4cbe3bc0cd328731deee1ba9bd237 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 4 Dec 2024 01:07:19 -0800 Subject: [PATCH 16/44] TensorProduct working with script() Signed-off-by: Boris Fomitchev --- .../primitives/equivariant_tensor_product.py | 1 - .../primitives/symmetric_tensor_product.py | 8 +- .../primitives/tensor_product.py | 163 ++++++++++++------ .../tests/primitives/tensor_product_test.py | 7 +- 4 files changed, 118 insertions(+), 61 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index ed0bbd1..4c540b6 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -145,7 +145,6 @@ def forward( """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - inputs: List[torch.Tensor] = list(inputs) assert len(inputs) == len(self.etp.inputs) for a, dim in zip(inputs, self.operands_dims): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 79d92a6..7011d71 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -203,7 +203,7 @@ def forward( x0.ndim == 2, f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) - shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1]) + shape = broadcast_shapes(i0.shape, x1.shape[:-1]) i0 = i0.expand(shape).reshape((math.prod(shape),)) x1 = x1.expand(shape + (x1.shape[-1],)).reshape( (math.prod(shape), x1.shape[-1]) @@ -335,9 +335,9 @@ def forward( i0 = i0.to(torch.int32) x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u) x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u) - logger.debug( - f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" - ) + # logger.debug( + # f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" + # ) out = self.f(x1, x0, i0) out = out.reshape(out.shape[0], out.shape[1] * self.u) return out diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 1dceb9e..10b7080 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -19,11 +19,48 @@ import torch import torch.fx - +from torch.jit import Final from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) +def prod(numbers: List[int]): + product = 1 + for num in numbers: + product *= num + return product + +def broadcast_shapes(shapes: List[List[int]]): + if torch.jit.is_scripting(): + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" + .format(shape[i], shape[i])) + if shape[i] == 1 or shape[i] == result[i]: + continue + if result[i] != 1: + raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape") + result[i] = shape[i] + else: + raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape) + return torch.Size(result) + else: + return torch.functional.broadcast_shapes(*shapes) + class TensorProduct(torch.nn.Module): """ @@ -36,25 +73,30 @@ class TensorProduct(torch.nn.Module): optimize_fallback (bool, optional): If `True`, the fallback method is optimized. If `False`, the fallback method is used without optimization. """ + num_operands: Final[int] + def __init__( self, descriptor: stp.SegmentedTensorProduct, *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() self.descriptor = descriptor - + # for script() + self.num_operands = descriptor.num_operands if math_dtype is None: math_dtype = torch.get_default_dtype() try: - self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype) + self.f_cuda3, self.f_cuda4 = _tensor_product_cuda(descriptor, device, math_dtype) except NotImplementedError as e: logger.info(f"CUDA implementation not available: {e}") - self.f_cuda = None + self.f_cuda3 = None + self.f_cuda4 = None except ImportError as e: logger.warning(f"CUDA implementation not available: {e}") logger.warning( @@ -63,16 +105,20 @@ def __init__( "pip install cuequivariance-ops-torch-cu11\n" "pip install cuequivariance-ops-torch-cu12" ) - self.f_cuda = None + self.f_cuda3 = None + self.f_cuda4 = None - self.f_fx = _tensor_product_fx( - descriptor, device, math_dtype, optimize_fallback is True - ) + if use_fallback == True: + self.f_fx = _tensor_product_fx( + descriptor, device, math_dtype, optimize_fallback is True + ) + else: + self.f_fx = None self._optimize_fallback = optimize_fallback def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.f_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" if self.f_cuda3 is not None or self.f_cuda4 is not None else "(without CUDA kernel)" ) return f"TensorProduct({self.descriptor} {has_cuda_kernel})" @@ -103,13 +149,15 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non if ( inputs and inputs[0].device.type == "cuda" - and self.f_cuda is not None and (use_fallback is not True) ): - return self.f_cuda(*inputs) + if self.f_cuda3 is not None: + return self.f_cuda3(inputs[0], inputs[1]) + else: + return self.f_cuda4(inputs[0], inputs[1], inputs[2]) if use_fallback is False: - if self.f_cuda is not None: + if self.f_cuda3 is not None and self.f_cuda4 is not None: raise RuntimeError("CUDA kernel available but input is not on CUDA") else: raise RuntimeError("No CUDA kernel available") @@ -119,6 +167,8 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non "The fallback method is used but it has not been optimized. " "Consider setting optimize_fallback=True when creating the TensorProduct module." ) + if self.f_fx is None: + raise RuntimeError("No fallback method available") return self.f_fx(inputs) @@ -190,7 +240,7 @@ def _tensor_product_fx( seg_shape = descriptor.get_segment_shape(-1, path) outputs += [ out.reshape( - out.shape[: out.ndim - len(seg_shape)] + (math.prod(seg_shape),) + out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),) ) ] @@ -206,7 +256,7 @@ def _tensor_product_fx( for out, path in zip(outputs, descriptor.paths) if path.indices[-1] == i ], - shape=batch_shape + (math.prod(descriptor.operands[-1][i]),), + shape=batch_shape + (prod(descriptor.operands[-1][i]),), like=outputs[0], ) for i in range(descriptor.operands[-1].num_segments) @@ -252,7 +302,7 @@ def __init__(self, descriptor: stp.SegmentedTensorProduct): ) def forward(self, *args): - shape = torch.broadcast_shapes(*[arg.shape[:-1] for arg in args]) + shape = broadcast_shapes([arg.shape[:-1] for arg in args]) output = torch.zeros( shape + (descriptor.operands[-1].size,), device=device, @@ -278,22 +328,23 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.module = module self.descriptor = descriptor - def forward(self, args): - for oid, arg in enumerate(args): - torch._assert( - arg.shape[-1] == self.descriptor.operands[oid].size, - "input shape[-1] does not match operand size", - ) + def forward(self, args:List[torch.Tensor]): + if not torch.jit.is_scripting(): + for oid, arg in enumerate(args): + torch._assert( + arg.shape[-1] == self.descriptor.operands[oid].size, + "input shape[-1] does not match operand size", + ) - shape = torch.broadcast_shapes(*[arg.shape[:-1] for arg in args]) + shape = broadcast_shapes([arg.shape[:-1] for arg in args]) args = [ ( arg.expand(shape + (arg.shape[-1],)).reshape( - (math.prod(shape), arg.shape[-1]) + (prod(shape), arg.shape[-1]) ) - if math.prod(arg.shape[:-1]) > 1 - else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1])) + if prod(arg.shape[:-1]) > 1 + else arg.reshape((prod(arg.shape[:-1]), arg.shape[-1])) ) for arg in args ] @@ -353,9 +404,9 @@ def _tensor_product_cuda( operand_num_segments=[o.num_segments for o in d.operands], ): if descriptor.num_operands == 3: - return TensorProductUniform3x1d(d, device, math_dtype) + return TensorProductUniform3x1d(d, device, math_dtype), None else: - return TensorProductUniform4x1d(d, device, math_dtype) + return None, TensorProductUniform4x1d(d, device, math_dtype) supported_targets = [ stp.Subscripts(subscripts) @@ -385,18 +436,18 @@ def _tensor_product_cuda( ) if descriptor.num_operands == 3: - return FusedTensorProductOp3(descriptor, perm[:2], device, math_dtype) + return FusedTensorProductOp3(descriptor, perm[:2], device, math_dtype), None elif descriptor.num_operands == 4: - return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) - + return None, FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) -def _reshape(x: torch.Tensor, leading_shape: tuple[int, ...]) -> torch.Tensor: + +def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: # Make x have shape (Z, x.shape[-1]) or (x.shape[-1],) - if math.prod(leading_shape) > 1 and math.prod(x.shape[:-1]) == 1: + if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1: return x.reshape((x.shape[-1],)) else: return x.expand(leading_shape + (x.shape[-1],)).reshape( - (math.prod(leading_shape), x.shape[-1]) + (prod(leading_shape), x.shape[-1]) ) @@ -434,7 +485,7 @@ def __init__( ).to(device=device) def __repr__(self) -> str: - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"FusedTensorProductOp3({self.descriptor} (output last operand))" def forward( self, @@ -445,13 +496,14 @@ def forward( assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) - logger.debug( - f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" - ) + if not torch.jit.is_scripting(): + logger.debug( + f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" + ) out = self._f(x0, x1) @@ -492,7 +544,7 @@ def __init__( ).to(device=device) def __repr__(self) -> str: - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"FusedTensorProductOp4({self.descriptor} (output last operand))" def forward( self, @@ -505,14 +557,15 @@ def forward( assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) - logger.debug( - f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" - ) + if not torch.jit.is_scripting(): + logger.debug( + f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" + ) out = self._f(x0, x1, x2) @@ -546,13 +599,13 @@ def __init__( ).to(device=device) def __repr__(self): - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" - def forward(self, x0, x1): + def forward(self, x0:torch.Tensor, x1:torch.Tensor): assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) @@ -561,9 +614,10 @@ def forward(self, x0, x1): if x1.ndim == 1: x1 = x1.unsqueeze(0) - logger.debug( - f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" - ) + if not torch.jit.is_scripting(): + logger.debug( + f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" + ) out = self._f(x0, x1) @@ -597,14 +651,14 @@ def __init__( ).to(device=device) def __repr__(self): - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" def forward(self, x0, x1, x2): assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) @@ -616,9 +670,10 @@ def forward(self, x0, x1, x2): if x2.ndim == 1: x2 = x2.unsqueeze(0) - logger.debug( - f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" - ) + if not torch.jit.is_scripting(): + logger.debug( + f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" + ) out = self._f(x0, x1, x2) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 53e8bfc..dc6c2e9 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -111,9 +111,11 @@ def test_primitive_tensor_product_cuda_vs_fx( m = cuet.TensorProduct( d, device=device, math_dtype=math_dtype, optimize_fallback=False ) - out1 = m(inputs, use_fallback=False) + m = torch.jit.script(m) + out1 = m(inputs) + m = cuet.TensorProduct( - d, device=device, math_dtype=torch.float64, optimize_fallback=False + d, device=device, math_dtype=torch.float64, use_fallback=True, optimize_fallback=False ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] out2 = m(inputs_, use_fallback=True) @@ -134,3 +136,4 @@ def test_primitive_tensor_product_cuda_vs_fx( for g1, g2 in zip(double_grad1, double_grad2): torch.testing.assert_close(g1, g2.to(dtype), atol=100 * tol, rtol=100 * tol) + From dc9d5b0e77164ad8e968742deac70fccdfa1fe8c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 4 Dec 2024 01:26:41 -0800 Subject: [PATCH 17/44] add 4 operands tests --- .../tests/primitives/script_test.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/script_test.py index 8829041..44e880f 100644 --- a/cuequivariance_torch/tests/primitives/script_test.py +++ b/cuequivariance_torch/tests/primitives/script_test.py @@ -6,7 +6,9 @@ ) from cuequivariance_torch.primitives.tensor_product import ( FusedTensorProductOp3, + FusedTensorProductOp4, TensorProductUniform3x1d, + TensorProductUniform4x1d, ) @@ -28,7 +30,7 @@ def test_script_symmetric_contraction(): assert module(x0, i0, x1).shape == (batch, ds[0].operands[-1].size) -def test_script_fused_tp(): +def test_script_fused_tp_3(): d = ( cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") @@ -47,7 +49,28 @@ def test_script_fused_tp(): assert module(x0, x1).shape == (batch, d.operands[2].size) -def test_script_uniform_tp(): +def test_script_fused_tp_4(): + d = ( + cue.descriptors.fully_connected_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + .permute_operands([1, 2, 0, 3]) + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device="cuda:0", dtype=torch.float32) + + module = FusedTensorProductOp4(d, (0, 1, 2), torch.device("cuda:0"), torch.float32) + module = torch.jit.script(module) + + assert module(x0, x1, x2).shape == (batch, d.operands[3].size) + + +def test_script_uniform_tp_3(): d = ( cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") @@ -64,3 +87,23 @@ def test_script_uniform_tp(): module = torch.jit.script(module) assert module(x0, x1).shape == (batch, d.operands[2].size) + + +def test_script_uniform_tp_4(): + d = ( + cue.descriptors.channelwise_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device="cuda:0", dtype=torch.float32) + + module = TensorProductUniform4x1d(d, torch.device("cuda:0"), torch.float32) + module = torch.jit.script(module) + + assert module(x0, x1, x2).shape == (batch, d.operands[3].size) From 334b4604dd506537155de9bb49d4482ef2e4c2a8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 4 Dec 2024 20:12:41 -0800 Subject: [PATCH 18/44] Unit tests run Signed-off-by: Boris Fomitchev --- .../layers/tp_conv_fully_connected.py | 6 + .../cuequivariance_torch/operations/linear.py | 13 +- .../operations/rotation.py | 17 +- .../operations/spherical_harmonics.py | 6 + .../operations/symmetric_contraction.py | 13 +- .../operations/tp_channel_wise.py | 13 +- .../operations/tp_fully_connected.py | 12 +- .../primitives/equivariant_tensor_product.py | 152 +++++++++++------- .../primitives/symmetric_tensor_product.py | 81 ++++------ .../primitives/tensor_product.py | 125 ++++++-------- .../primitives/transpose.py | 43 ++--- .../equivariant_tensor_product_test.py | 48 ++++-- .../symmetric_tensor_product_test.py | 15 +- .../tests/primitives/tensor_product_test.py | 2 +- .../tests/primitives/transpose_test.py | 4 +- 15 files changed, 286 insertions(+), 264 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index 3d842ea..e4b3a58 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -56,6 +56,10 @@ class FullyConnectedTensorProductConv(nn.Module): mlp_channels (Sequence of int, optional): A sequence of integers defining the number of neurons in each layer in MLP before the output layer. If None, no MLP will be added. The input layer contains edge embeddings and node scalar features. Defaults to None. mlp_activation (``nn.Module`` or Sequence of ``nn.Module``, optional): A sequence of functions to be applied in between linear layers in MLP, e.g., ``nn.Sequential(nn.ReLU(), nn.Dropout(0.4))``. Defaults to ``nn.GELU()``. layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: >>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o") @@ -121,6 +125,7 @@ def __init__( mlp_channels: Optional[Sequence[int]] = None, mlp_activation: Union[nn.Module, Sequence[nn.Module], None] = nn.GELU(), layout: cue.IrrepsLayout = None, # e3nn_compat_mode + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -141,6 +146,7 @@ def __init__( out_irreps, layout=self.layout, shared_weights=False, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 197f5ac..977ff2d 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -32,6 +32,10 @@ class Linear(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the irreducible representations, by default ``cue.mul_ir``. This is the layout used in the e3nn library. shared_weights (bool, optional): Whether to use shared weights, by default True. internal_weights (bool, optional): Whether to use internal weights, by default True if shared_weights is True, otherwise False. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -47,6 +51,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -84,6 +89,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -94,8 +100,6 @@ def forward( self, x: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Forward pass of the linear layer. @@ -103,9 +107,6 @@ def forward( Args: x (torch.Tensor): The input tensor. weight (torch.Tensor, optional): The weight tensor. If None, the internal weight tensor is used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The output tensor after applying the linear transformation. @@ -126,4 +127,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x], use_fallback=use_fallback) + return self.f([weight, x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index 9fb4d86..9c03468 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -40,6 +40,7 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -60,6 +61,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -69,8 +71,6 @@ def forward( beta: torch.Tensor, alpha: torch.Tensor, x: torch.Tensor, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Forward pass of the rotation layer. @@ -80,9 +80,6 @@ def forward( beta (torch.Tensor): The beta angles. Second rotation around the x-axis. alpha (torch.Tensor): The alpha angles. Third rotation around the y-axis. x (torch.Tensor): The input tensor. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The rotated tensor. @@ -97,7 +94,6 @@ def forward( return self.f( [encodings_gamma, encodings_beta, encodings_alpha, x], - use_fallback=use_fallback, ) @@ -159,6 +155,11 @@ class Inversion(torch.nn.Module): Args: irreps (Irreps): The irreducible representations of the tensor to invert. layout (IrrepsLayout, optional): The memory layout of the tensor, ``cue.ir_mul`` is preferred. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -170,6 +171,8 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, + optimize_fallback: Optional[bool] = None, ): super().__init__() (irreps,) = default_irreps(irreps) @@ -187,6 +190,8 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py index bfd0163..b2ecc93 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py @@ -25,6 +25,7 @@ def spherical_harmonics( ls: list[int], vectors: torch.Tensor, normalize: bool = True, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ) -> torch.Tensor: r"""Compute the spherical harmonics of the input vectors. @@ -33,6 +34,10 @@ def spherical_harmonics( ls (list of int): List of spherical harmonic degrees. vectors (torch.Tensor): Input vectors of shape (..., 3). normalize (bool, optional): Whether to normalize the input vectors. Defaults to True. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Returns: @@ -53,6 +58,7 @@ def spherical_harmonics( layout=cue.ir_mul, device=x.device, math_dtype=x.dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) y = m([x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 35f1b56..fac5739 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -38,6 +38,10 @@ class SymmetricContraction(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. If not provided, a default layout is used. math_dtype (torch.dtype, optional): The data type for mathematical operations. If not specified, the default data type from the torch environment is used. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: >>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o") @@ -102,6 +106,7 @@ def __init__( dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, original_mace: bool = False, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -147,6 +152,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype or dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -160,8 +166,6 @@ def forward( self, x: torch.Tensor, indices: torch.Tensor, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the symmetric contraction operation. @@ -170,9 +174,6 @@ def forward( x (torch.Tensor): The input tensor. It should have shape (..., irreps_in.dim). indices (torch.Tensor): The index of the weight to use for each batch element. It should have shape (...). - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The output tensor. It has shape (batch, irreps_out.dim). @@ -188,4 +189,4 @@ def forward( weight = self.weight weight = weight.flatten(1) - return self.f([weight, x], indices=indices, use_fallback=use_fallback) + return self.f([weight, x], indices=indices) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index 76402e7..169a248 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -33,6 +33,10 @@ class ChannelWiseTensorProduct(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. shared_weights (bool, optional): Whether to share weights across the batch dimension. Default is True. internal_weights (bool, optional): Whether to create module parameters for weights. Default is None. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. @@ -54,6 +58,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -95,6 +100,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -110,8 +116,6 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the fully connected tensor product operation. @@ -122,9 +126,6 @@ def forward( weight (torch.Tensor, optional): Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -147,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x1, x2], use_fallback=use_fallback) + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index f33c7f6..fd7706f 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -33,6 +33,10 @@ class FullyConnectedTensorProduct(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. shared_weights (bool, optional): Whether to share weights across the batch dimension. Default is True. internal_weights (bool, optional): Whether to create module parameters for weights. Default is None. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. @@ -54,6 +58,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -111,8 +116,6 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the fully connected tensor product operation. @@ -123,9 +126,6 @@ def forward( weight (torch.Tensor, optional): Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -148,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x1, x2], use_fallback=use_fallback) + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 4c540b6..06ea215 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -21,6 +21,59 @@ from cuequivariance.irreps_array.misc_ui import default_layout +class Dispatcher(torch.nn.Module): + def __init__(self, tp): + super().__init__() + self.tp = tp + + +class TPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if indices is not None: + # TODO: at some point we will have kernel for this + assert len(inputs) >= 1 + inputs[0] = inputs[0][indices] + return self.tp(inputs) + + +class SymmetricTPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert indices is None + return self.tp(inputs[0]) + +class IWeightedSymmetricTPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x0 = inputs[0] + x1 = inputs[1] + if indices is None: + torch._assert( + x0.ndim == 2, + f"Expected x0 to have shape (batch, dim), got {x0.shape}", + ) + if x0.shape[0] == 1: + indices = torch.zeros( + (x1.shape[0],), dtype=torch.int32, device=x1.device + ) + else: # x0.shape[0] == x1.shape[0]: + indices = torch.arange( + x1.shape[0], dtype=torch.int32, device=x1.device + ) + # borisf : why was it here ? + # if indices is not None: + return self.tp(x0, indices, x1) + class EquivariantTensorProduct(torch.nn.Module): r"""Equivariant tensor product. @@ -31,7 +84,10 @@ class EquivariantTensorProduct(torch.nn.Module): layout_out (IrrepsLayout): layout for output. device (torch.device): device of the Module. math_dtype (torch.dtype): dtype for internal computations. + use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. optimize_fallback (bool): whether to optimize the fallback implementation. + Raises: + RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. Examples: >>> e = cue.descriptors.fully_connected_tensor_product( @@ -55,7 +111,7 @@ class EquivariantTensorProduct(torch.nn.Module): ... [0., 0., 0., 0., 0., 0.]]) """ - + def __init__( self, e: cue.EquivariantTensorProduct, @@ -67,6 +123,7 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -92,6 +149,7 @@ def __init__( source=layout_used, target=input_expected.layout, device=device, + use_fallback = use_fallback ) ) self.transpose_out = cuet.TransposeIrrepsLayout( @@ -99,37 +157,42 @@ def __init__( source=e.output.layout, target=layout_out, device=device, + use_fallback = use_fallback ) if any(d.num_operands != e.num_inputs + 1 for d in e.ds): - self.tp = None - if e.num_inputs == 1: - self.symm_tp = cuet.SymmetricTensorProduct( - e.ds, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = SymmetricTPDispatcher( + cuet.SymmetricTensorProduct( + e.ds, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, + ) ) elif e.num_inputs == 2: - self.symm_tp = cuet.IWeightedSymmetricTensorProduct( - e.ds, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = IWeightedSymmetricTPDispatcher( + cuet.IWeightedSymmetricTensorProduct( + e.ds, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, + ) ) else: raise NotImplementedError("This should not happen") else: - [d] = e.ds - - self.tp = cuet.TensorProduct( - d, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = TPDispatcher( + cuet.TensorProduct( + e.ds[0], + device=device, + math_dtype=math_dtype, + use_fallback = use_fallback, + optimize_fallback=optimize_fallback, + ) ) - self.symm_tp = None self.operands_dims = [op.irreps.dim for op in e.operands] @@ -140,59 +203,24 @@ def forward( self, inputs: List[torch.Tensor], indices: Optional[torch.Tensor] = None, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - assert len(inputs) == len(self.etp.inputs) + # assert len(inputs) == len(self.etp.inputs) for a, dim in zip(inputs, self.operands_dims): assert a.shape[-1] == dim # Transpose inputs - inputs = [ - t(a, use_fallback=use_fallback) for t, a in zip(self.transpose_in, inputs) - ] + inputs[0] = self.transpose_in[0](inputs[0]) + if len(self.transpose_in) > 1: + inputs[1] = self.transpose_in[1](inputs[1]) # Compute tensor product - output = None - - if self.tp is not None: - if indices is not None: - # TODO: at some point we will have kernel for this - assert len(inputs) >= 1 - inputs[0] = inputs[0][indices] - output = self.tp(inputs, use_fallback=use_fallback) - - if self.symm_tp is not None: - if len(inputs) == 1: - assert indices is None - output = self.symm_tp(inputs[0], use_fallback=use_fallback) - - if len(inputs) == 2: - [x0, x1] = inputs - if indices is None: - torch._assert( - x0.ndim == 2, - f"Expected x0 to have shape (batch, dim), got {x0.shape}", - ) - if x0.shape[0] == 1: - indices = torch.zeros( - (x1.shape[0],), dtype=torch.int32, device=x1.device - ) - elif x0.shape[0] == x1.shape[0]: - indices = torch.arange( - x1.shape[0], dtype=torch.int32, device=x1.device - ) - - if indices is not None: - output = self.symm_tp(x0, indices, x1, use_fallback=use_fallback) - - if output is None: - raise NotImplementedError("This should not happen") + output = self.tp(inputs, indices) # Transpose output - output = self.transpose_out(output, use_fallback=use_fallback) + output = self.transpose_out(output) return output diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 7011d71..8f8ea0d 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -21,6 +21,7 @@ import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet +from cuequivariance_torch.primitives.tensor_product import broadcast_shapes, prod logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ def __init__( *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -55,6 +57,7 @@ def __init__( d0, device=device, math_dtype=math_dtype, + use_fallback = use_fallback, optimize_fallback=optimize_fallback, ) else: @@ -86,7 +89,7 @@ def __init__( ) def forward( - self, x0: torch.Tensor, use_fallback: Optional[bool] = None + self, x0: torch.Tensor ) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -105,8 +108,7 @@ def forward( out = self.f( torch.ones((1, 1), dtype=x0.dtype, device=x0.device), torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device), - x0, - use_fallback=use_fallback, + x0 ) if self.f0 is not None: out += self.f0([]) @@ -134,6 +136,7 @@ def __init__( *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -141,30 +144,35 @@ def __init__( _check_descriptors(descriptors) self.descriptors = descriptors - try: - self.f_cuda = CUDAKernel(descriptors, device, math_dtype) - except NotImplementedError as e: - logger.info(f"Failed to initialize CUDA implementation: {e}") - self.f_cuda = None - except ImportError as e: - logger.warning(f"Failed to initialize CUDA implementation: {e}") - self.f_cuda = None - - self.f_fx = FallbackImpl( - descriptors, - device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, - ) - d = next(d for d in descriptors if d.num_operands >= 3) self.x0_size = d.operands[0].size self.x1_size = d.operands[1].size self.x2_size = d.operands[-1].size + self.has_cuda = False + + if not use_fallback == True: + try: + self.f = CUDAKernel(descriptors, device, math_dtype) + self.has_cuda = True + return + except NotImplementedError as e: + logger.info(f"Failed to initialize CUDA implementation: {e}") + except ImportError as e: + logger.warning(f"Failed to initialize CUDA implementation: {e}") + + if use_fallback == False: + raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available") + else: + self.f = FallbackImpl( + descriptors, + device, + math_dtype=math_dtype, + optimize_fallback=optimize_fallback, + ) def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.f_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" if self.has_cuda is not None else "(without CUDA kernel)" ) return f"IWeightedSymmetricTensorProduct({has_cuda_kernel})" @@ -173,7 +181,6 @@ def forward( x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -187,10 +194,6 @@ def forward( The index tensor for the first operand. It should have the shape (...). x1 : torch.Tensor The repeated input tensor. It should have the shape (..., x1_size). - use_fallback : Optional[bool], optional - If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns ------- @@ -203,32 +206,18 @@ def forward( x0.ndim == 2, f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) - shape = broadcast_shapes(i0.shape, x1.shape[:-1]) - i0 = i0.expand(shape).reshape((math.prod(shape),)) + shape = broadcast_shapes([i0.shape, x1.shape[:-1]]) + i0 = i0.expand(shape).reshape((prod(shape),)) x1 = x1.expand(shape + (x1.shape[-1],)).reshape( - (math.prod(shape), x1.shape[-1]) + (prod(shape), x1.shape[-1]) ) - - if ( - x0.device.type == "cuda" - and self.f_cuda is not None - and (use_fallback is not True) - ): - out = self.f_cuda(x0, i0, x1) - out = out.reshape(shape + (self.x2_size,)) - return out - - if use_fallback is False: - if self.f_cuda is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") - - out = self.f_fx(x0, i0, x1) + + out = self.f(x0, i0, x1) out = out.reshape(shape + (self.x2_size,)) return out + def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") @@ -368,6 +357,6 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True) + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 10b7080..941b34c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -70,11 +70,13 @@ class TensorProduct(torch.nn.Module): descriptor (SegmentedTensorProduct): The descriptor of the segmented tensor product. math_dtype (torch.dtype, optional): The data type of the coefficients and calculations. device (torch.device, optional): The device on which the calculations are performed. - optimize_fallback (bool, optional): If `True`, the fallback method is optimized. If `False`, the fallback method is used without optimization. - """ + use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. - num_operands: Final[int] + optimize_fallback (bool, optional): If `True`, the fallback method is optimized. If `False`, the fallback method is used without optimization. + Raises: + RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. + """ def __init__( self, descriptor: stp.SegmentedTensorProduct, @@ -86,43 +88,47 @@ def __init__( ): super().__init__() self.descriptor = descriptor - # for script() - self.num_operands = descriptor.num_operands if math_dtype is None: math_dtype = torch.get_default_dtype() - - try: - self.f_cuda3, self.f_cuda4 = _tensor_product_cuda(descriptor, device, math_dtype) - except NotImplementedError as e: - logger.info(f"CUDA implementation not available: {e}") - self.f_cuda3 = None - self.f_cuda4 = None - except ImportError as e: - logger.warning(f"CUDA implementation not available: {e}") - logger.warning( - "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" - "Install it with one of the following commands:\n" - "pip install cuequivariance-ops-torch-cu11\n" - "pip install cuequivariance-ops-torch-cu12" - ) - self.f_cuda3 = None - self.f_cuda4 = None - - if use_fallback == True: - self.f_fx = _tensor_product_fx( + self.f = None + self.has_cuda = False + + if not use_fallback == True: + try: + self.f = _tensor_product_cuda(descriptor, device, math_dtype) + self.has_cuda = True + return + except NotImplementedError as e: + logger.info(f"CUDA implementation not available: {e}") + except ImportError as e: + logger.warning(f"CUDA implementation not available: {e}") + logger.warning( + "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" + "Install it with one of the following commands:\n" + "pip install cuequivariance-ops-torch-cu11\n" + "pip install cuequivariance-ops-torch-cu12" + ) + + if use_fallback == False: + raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available!") + else: + self.f = _tensor_product_fx( descriptor, device, math_dtype, optimize_fallback is True ) - else: - self.f_fx = None - self._optimize_fallback = optimize_fallback + if optimize_fallback is None: + warnings.warn( + "The fallback method is used but it has not been optimized. " + "Consider setting optimize_fallback=True when creating the TensorProduct module." + ) + self._optimize_fallback = optimize_fallback def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.f_cuda3 is not None or self.f_cuda4 is not None else "(without CUDA kernel)" + "(with CUDA kernel)" if self.has_cuda else "(without CUDA kernel)" ) return f"TensorProduct({self.descriptor} {has_cuda_kernel})" - def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None): + def forward(self, inputs: List[torch.Tensor]): r""" Perform the tensor product based on the specified descriptor. @@ -130,9 +136,6 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size of each operand as defined in the tensor product descriptor. - use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input - is on CUDA. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available or the - input is not on CUDA. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -140,36 +143,11 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non It has a shape of (batch, last_operand_size), where `last_operand_size` is the size of the last operand in the descriptor. - Raises: - RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ - if any(x.numel() == 0 for x in inputs): - use_fallback = True # Empty tensors are not supported by the CUDA kernel - - if ( - inputs - and inputs[0].device.type == "cuda" - and (use_fallback is not True) - ): - if self.f_cuda3 is not None: - return self.f_cuda3(inputs[0], inputs[1]) - else: - return self.f_cuda4(inputs[0], inputs[1], inputs[2]) + # if any(x.numel() == 0 for x in inputs): + # use_fallback = True # Empty tensors are not supported by the CUDA kernel - if use_fallback is False: - if self.f_cuda3 is not None and self.f_cuda4 is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") - - if self._optimize_fallback is None: - warnings.warn( - "The fallback method is used but it has not been optimized. " - "Consider setting optimize_fallback=True when creating the TensorProduct module." - ) - if self.f_fx is None: - raise RuntimeError("No fallback method available") - return self.f_fx(inputs) + return self.f(inputs) def _tensor_product_fx( @@ -404,9 +382,9 @@ def _tensor_product_cuda( operand_num_segments=[o.num_segments for o in d.operands], ): if descriptor.num_operands == 3: - return TensorProductUniform3x1d(d, device, math_dtype), None + return TensorProductUniform3x1d(d, device, math_dtype) else: - return None, TensorProductUniform4x1d(d, device, math_dtype) + return TensorProductUniform4x1d(d, device, math_dtype) supported_targets = [ stp.Subscripts(subscripts) @@ -436,9 +414,9 @@ def _tensor_product_cuda( ) if descriptor.num_operands == 3: - return FusedTensorProductOp3(descriptor, perm[:2], device, math_dtype), None + return FusedTensorProductOp3(descriptor, perm[:2], device, math_dtype) elif descriptor.num_operands == 4: - return None, FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) + return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: @@ -489,10 +467,9 @@ def __repr__(self) -> str: def forward( self, - x0: torch.Tensor, - x1: torch.Tensor, + inputs: List[torch.Tensor] ) -> torch.Tensor: - x0, x1 = self._perm(x0, x1) + x0, x1 = self._perm(inputs[0], inputs[1]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -548,11 +525,9 @@ def __repr__(self) -> str: def forward( self, - x0: torch.Tensor, - x1: torch.Tensor, - x2: torch.Tensor, + inputs: List[torch.Tensor] ) -> torch.Tensor: - x0, x1, x2 = self._perm(x0, x1, x2) + x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim @@ -601,7 +576,8 @@ def __init__( def __repr__(self): return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" - def forward(self, x0:torch.Tensor, x1:torch.Tensor): + def forward(self, inputs: List[torch.Tensor]): + x0, x1 = inputs assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -653,7 +629,8 @@ def __init__( def __repr__(self): return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" - def forward(self, x0, x1, x2): + def forward(self, inputs: List[torch.Tensor]): + x0, x1, x2 = inputs assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 4c40036..252e45e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -36,19 +36,22 @@ def __init__( source: cue.IrrepsLayout, target: cue.IrrepsLayout, device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False ): super().__init__() if (source, target) == (cue.mul_ir, cue.ir_mul): self.f = TransposeSegments( - [(mul, ir.dim) for mul, ir in irreps], device=device + [(mul, ir.dim) for mul, ir in irreps], device=device, + use_fallback = use_fallback ) elif (source, target) == (cue.ir_mul, cue.mul_ir): self.f = TransposeSegments( - [(ir.dim, mul) for mul, ir in irreps], device=device + [(ir.dim, mul) for mul, ir in irreps], device=device, + use_fallback = use_fallback ) else: - self.f = _Identity() + self.f = torch.nn.Identity() self.source, self.target = source, target @@ -59,7 +62,7 @@ def __repr__(self): return f"TransposeIrrepsLayout({self.source} -> {self.target})" def forward( - self, x: torch.Tensor, use_fallback: Optional[bool] = None + self, x: torch.Tensor ) -> torch.Tensor: r""" Perform the transposition. @@ -74,17 +77,13 @@ def forward( torch.Tensor: The transposed tensor. """ - return self.f(x, use_fallback=use_fallback) - - -class _Identity(torch.nn.Module): - def forward(self, x: torch.Tensor, **kwargs): - return x + return self.f(x) class TransposeSegments(torch.nn.Module): def __init__( - self, segments: list[tuple[int, int]], device: Optional[torch.device] = None + self, segments: list[tuple[int, int]], device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False ): super().__init__() @@ -97,8 +96,8 @@ def __init__( self.f_cuda = None else: self.f_cuda = _transpose(info).to(device=device) - - self.f = _transpose_segments_fx(segments).to(device=device) + if use_fallback: + self.f = _transpose_segments_fx(segments).to(device=device) else: self.f_cuda = torch.nn.Identity() self.f = torch.nn.Identity() @@ -107,7 +106,7 @@ def __repr__(self): return "TransposeSegments()" def forward( - self, x: torch.Tensor, use_fallback: Optional[bool] = None + self, x: torch.Tensor ) -> torch.Tensor: """ Perform the transposition of the input tensor using either a CUDA kernel or a PyTorch fallback. @@ -131,20 +130,10 @@ def forward( RuntimeError If `use_fallback` is `False` and a CUDA kernel is not available or the input is not on CUDA. """ - if ( - x.device.type == "cuda" - and self.f_cuda is not None - and (use_fallback is not True) - ): + if self.f_cuda is not None: return self.f_cuda(x) - - if use_fallback is False: - if self.f_cuda is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") - - return self.f(x) + else: + return self.f(x) def _transpose_segments_fx(segments: list[tuple[int, int]]) -> torch.nn.Module: diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index e040687..f9f6a7b 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -75,6 +75,15 @@ def test_performance_cuda_vs_fx( layout=cue.ir_mul, device=device, math_dtype=math_dtype, + use_fallback=False, + optimize_fallback=True, + ) + m1 = cuet.EquivariantTensorProduct( + e, + layout=cue.ir_mul, + device=device, + math_dtype=math_dtype, + use_fallback=True, optimize_fallback=True, ) @@ -84,15 +93,19 @@ def test_performance_cuda_vs_fx( ] for _ in range(10): - m(inputs, use_fallback=False) - m(inputs, use_fallback=True) + m(inputs) + m1(inputs) + + def f(): + m(inputs) + torch.cuda.synchronize() - def f(ufb: bool): - m(inputs, use_fallback=ufb) + def f1(): + m1(inputs) torch.cuda.synchronize() - t0 = timeit.Timer(lambda: f(False)).timeit(number=10) - t1 = timeit.Timer(lambda: f(True)).timeit(number=10) + t0 = timeit.Timer(f).timeit(number=10) + t1 = timeit.Timer(f1).timeit(number=10) assert t0 < t1 @@ -129,18 +142,20 @@ def test_precision_cuda_vs_fx( layout=cue.ir_mul, device=device, math_dtype=math_dtype, + use_fallback=False ) - y0 = m(inputs, use_fallback=False) + y0 = m(inputs) m = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, device=device, math_dtype=torch.float64, + use_fallback=True, optimize_fallback=True, ) - inputs = map(lambda x: x.to(torch.float64), inputs) - y1 = m(inputs, use_fallback=True).to(dtype) + inputs = [x.to(torch.float64) for x in inputs] + y1 = m(inputs).to(dtype) torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) @@ -149,10 +164,10 @@ def test_compile(): e = cue.descriptors.symmetric_contraction( cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] ) - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False) + m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, device="cuda", optimize_fallback=False) m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, e.inputs[0].irreps.dim) - input2 = torch.randn(100, e.inputs[1].irreps.dim) + input1 = torch.randn(100, e.inputs[0].irreps.dim).cuda() + input2 = torch.randn(100, e.inputs[1].irreps.dim).cuda() m_compile([input1, input2]) @@ -160,8 +175,11 @@ def test_script(): e = cue.descriptors.symmetric_contraction( cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] ) - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False) + m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, + use_fallback=False, + device="cuda", + optimize_fallback=False) m_script = torch.jit.script(m) - input1 = torch.randn(100, e.inputs[0].irreps.dim) - input2 = torch.randn(100, e.inputs[1].irreps.dim) + input1 = torch.randn(100, e.inputs[0].irreps.dim).cuda() + input2 = torch.randn(100, e.inputs[1].irreps.dim).cuda() m_script([input1, input2]) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 95dfc4d..909d71d 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -64,8 +64,9 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( device = torch.device("cuda:0") m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=math_dtype, device=device, optimize_fallback=False + ds, math_dtype=math_dtype, device=device, use_fallback=False ) + m = torch.jit.script(m) x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True) i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device) @@ -75,11 +76,11 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( x0_ = x0.clone().to(torch.float64) x1_ = x1.clone().to(torch.float64) - out1 = m(x0, i0, x1, use_fallback=False) + out1 = m(x0, i0, x1) m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=torch.float64, device=device, optimize_fallback=True + ds, math_dtype=torch.float64, device=device, use_fallback=True, optimize_fallback=True ) - out2 = m(x0_, i0, x1_, use_fallback=True) + out2 = m(x0_, i0, x1_) assert out1.dtype == dtype @@ -121,19 +122,19 @@ def test_math_dtype( ds = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds - m = cuet.IWeightedSymmetricTensorProduct(ds, math_dtype=math_dtype, device=device) + m = cuet.IWeightedSymmetricTensorProduct(ds, math_dtype=math_dtype, device=device, use_fallback=False) x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device) i0 = torch.randint(0, m.x0_size, (1000,), dtype=torch.int32, device=device) x1 = torch.randn((i0.size(0), m.x1_size), dtype=dtype, device=device) - out1 = m(x0, i0, x1, use_fallback=False) + out1 = m(x0, i0, x1) # .to should have no effect for param in m.parameters(): assert False # no parameters m = m.to(torch.float64) - out2 = m(x0, i0, x1, use_fallback=False) + out2 = m(x0, i0, x1) assert out1.dtype == dtype assert out2.dtype == dtype diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index dc6c2e9..c7bf8e2 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -118,7 +118,7 @@ def test_primitive_tensor_product_cuda_vs_fx( d, device=device, math_dtype=torch.float64, use_fallback=True, optimize_fallback=False ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] - out2 = m(inputs_, use_fallback=True) + out2 = m(inputs_) assert out1.shape[:-1] == torch.broadcast_shapes(*batches) assert out1.dtype == dtype diff --git a/cuequivariance_torch/tests/primitives/transpose_test.py b/cuequivariance_torch/tests/primitives/transpose_test.py index cd39cfc..67ad700 100644 --- a/cuequivariance_torch/tests/primitives/transpose_test.py +++ b/cuequivariance_torch/tests/primitives/transpose_test.py @@ -42,5 +42,5 @@ def test_transpose(use_fallback: bool, dtype: torch.dtype): [[1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 10, 12, 11, 13]], dtype=dtype ).cuda() - m = cuet.TransposeSegments(segments).cuda() - torch.testing.assert_close(m(x, use_fallback=use_fallback), xt) + m = cuet.TransposeSegments(segments, use_fallback=use_fallback).cuda() + torch.testing.assert_close(m(x), xt) From 79e7c5f750879c0b61f1d8b9e8b1ebdd15ada7ac Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 4 Dec 2024 20:32:47 -0800 Subject: [PATCH 19/44] Restoring debug logging Signed-off-by: Boris Fomitchev --- .../primitives/symmetric_tensor_product.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 8f8ea0d..c5ac0f0 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -324,9 +324,10 @@ def forward( i0 = i0.to(torch.int32) x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u) x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u) - # logger.debug( - # f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" - # ) + if not torch.jit.is_scripting(): + logger.debug( + f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" + ) out = self.f(x1, x0, i0) out = out.reshape(out.shape[0], out.shape[1] * self.u) return out From 6c5cdb023d1b57a937195265112d0c51ee0a05eb Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 5 Dec 2024 01:01:37 -0800 Subject: [PATCH 20/44] Parameterized script test Signed-off-by: Boris Fomitchev --- .../primitives/equivariant_tensor_product.py | 3 +- .../primitives/tensor_product.py | 34 +++------------- .../equivariant_tensor_product_test.py | 39 +++++++++++++------ .../tests/primitives/script_test.py | 8 ++-- 4 files changed, 37 insertions(+), 47 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 06ea215..549e38c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -55,8 +55,7 @@ def forward( inputs: List[torch.Tensor], indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: - x0 = inputs[0] - x1 = inputs[1] + x0, x1 = inputs if indices is None: torch._assert( x0.ndim == 2, diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 941b34c..0a246cc 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -546,8 +546,7 @@ def forward( return out.reshape(shape + (out.shape[-1],)) - -class TensorProductUniform3x1d(torch.nn.Module): +class TensorProductUniform1d(torch.nn.Module): def __init__( self, descriptor: stp.SegmentedTensorProduct, @@ -573,10 +572,11 @@ def __init__( math_dtype=math_dtype, ).to(device=device) +class TensorProductUniform3x1d(TensorProductUniform1d): def __repr__(self): return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" - def forward(self, inputs: List[torch.Tensor]): + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = inputs assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -595,36 +595,12 @@ def forward(self, inputs: List[torch.Tensor]): f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) - out = self._f(x0, x1) + out = self._f(x0, x1, x0) return out.reshape(shape + (out.shape[-1],)) -class TensorProductUniform4x1d(torch.nn.Module): - def __init__( - self, - descriptor: stp.SegmentedTensorProduct, - device: Optional[torch.device], - math_dtype: torch.dtype, - ): - super().__init__() - import cuequivariance_ops_torch as ops - - self.descriptor = descriptor - - assert len(descriptor.subscripts.modes()) == 1 - assert descriptor.all_same_segment_shape() - assert descriptor.coefficient_subscripts == "" - u = next(iter(descriptor.get_dims(descriptor.subscripts.modes()[0]))) - - self._f = ops.TensorProductUniform1d( - operand_dim=[ope.ndim for ope in descriptor.operands], - operand_extent=u, - operand_num_segments=[ope.num_segments for ope in descriptor.operands], - path_indices=[path.indices for path in descriptor.paths], - path_coefficients=[float(path.coefficients) for path in descriptor.paths], - math_dtype=math_dtype, - ).to(device=device) +class TensorProductUniform4x1d(TensorProductUniform1d): def __repr__(self): return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 42736a6..518d150 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -76,8 +76,8 @@ def test_performance_cuda_vs_fx( device=device, math_dtype=math_dtype, use_fallback=False, - optimize_fallback=True, ) + m1 = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, @@ -95,14 +95,17 @@ def test_performance_cuda_vs_fx( for _ in range(10): m(inputs) m1(inputs) + torch.cuda.synchronize() def f(): - m(inputs) - torch.cuda.synchronize() + ret = m(inputs) + ret = torch.sum(ret) + return ret def f1(): - m1(inputs) - torch.cuda.synchronize() + ret = m1(inputs) + ret = torch.sum(ret) + return ret t0 = timeit.Timer(f).timeit(number=10) t1 = timeit.Timer(f1).timeit(number=10) @@ -170,16 +173,28 @@ def test_compile(): input2 = torch.randn(100, e.inputs[1].irreps.dim).cuda() m_compile([input1, input2]) -def test_script(): - e = cue.descriptors.symmetric_contraction( - cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] - ) +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +def test_script( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, +): + + device = torch.device("cuda:0") + m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, use_fallback=False, device="cuda", optimize_fallback=False) + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) m_script = torch.jit.script(m) - input1 = torch.randn(100, e.inputs[0].irreps.dim).cuda() - input2 = torch.randn(100, e.inputs[1].irreps.dim).cuda() - m_script([input1, input2]) + # res_script = m_script(inputs) + # torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/script_test.py index 44e880f..37b2a0c 100644 --- a/cuequivariance_torch/tests/primitives/script_test.py +++ b/cuequivariance_torch/tests/primitives/script_test.py @@ -46,7 +46,7 @@ def test_script_fused_tp_3(): module = FusedTensorProductOp3(d, (0, 1), torch.device("cuda:0"), torch.float32) module = torch.jit.script(module) - assert module(x0, x1).shape == (batch, d.operands[2].size) + assert module([x0, x1]).shape == (batch, d.operands[2].size) def test_script_fused_tp_4(): @@ -67,7 +67,7 @@ def test_script_fused_tp_4(): module = FusedTensorProductOp4(d, (0, 1, 2), torch.device("cuda:0"), torch.float32) module = torch.jit.script(module) - assert module(x0, x1, x2).shape == (batch, d.operands[3].size) + assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) def test_script_uniform_tp_3(): @@ -86,7 +86,7 @@ def test_script_uniform_tp_3(): module = TensorProductUniform3x1d(d, torch.device("cuda:0"), torch.float32) module = torch.jit.script(module) - assert module(x0, x1).shape == (batch, d.operands[2].size) + assert module([x0, x1]).shape == (batch, d.operands[2].size) def test_script_uniform_tp_4(): @@ -106,4 +106,4 @@ def test_script_uniform_tp_4(): module = TensorProductUniform4x1d(d, torch.device("cuda:0"), torch.float32) module = torch.jit.script(module) - assert module(x0, x1, x2).shape == (batch, d.operands[3].size) + assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) From e21c45f57d7a8c3a4cbc5f7affc9445cb3b39eb7 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 5 Dec 2024 12:04:35 -0800 Subject: [PATCH 21/44] Fixed transpose for script(), script_test successful Signed-off-by: Boris Fomitchev --- .../primitives/equivariant_tensor_product.py | 49 +++++++++++++------ .../primitives/symmetric_tensor_product.py | 9 ++-- .../primitives/tensor_product.py | 1 - .../equivariant_tensor_product_test.py | 8 +-- 4 files changed, 43 insertions(+), 24 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 549e38c..a8dbd5e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -26,6 +26,31 @@ def __init__(self, tp): super().__init__() self.tp = tp +class Transpose1Dispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor] + ): + inputs[0] = self.tp[0](inputs[0]) + +class Transpose2Dispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor] + ): + inputs[0] = self.tp[0](inputs[0]) + inputs[1] = self.tp[1](inputs[1]) + +class Transpose3Dispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor] + ): + inputs[0] = self.tp[0](inputs[0]) + inputs[1] = self.tp[1](inputs[1]) + inputs[2] = self.tp[2](inputs[2]) + +TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher] class TPDispatcher(Dispatcher): def forward( @@ -61,16 +86,9 @@ def forward( x0.ndim == 2, f"Expected x0 to have shape (batch, dim), got {x0.shape}", ) - if x0.shape[0] == 1: - indices = torch.zeros( - (x1.shape[0],), dtype=torch.int32, device=x1.device - ) - else: # x0.shape[0] == x1.shape[0]: - indices = torch.arange( - x1.shape[0], dtype=torch.int32, device=x1.device - ) - # borisf : why was it here ? - # if indices is not None: + indices = torch.arange( + x1.shape[0], dtype=torch.int32, device=x1.device + ) return self.tp(x0, indices, x1) class EquivariantTensorProduct(torch.nn.Module): @@ -140,9 +158,9 @@ def __init__( self.layout_in = layout_in = tuple(map(default_layout, layout_in)) self.layout_out = layout_out = default_layout(layout_out) - self.transpose_in = torch.nn.ModuleList() + transpose_in = torch.nn.ModuleList() for layout_used, input_expected in zip(layout_in, e.inputs): - self.transpose_in.append( + transpose_in.append( cuet.TransposeIrrepsLayout( input_expected.irreps, source=layout_used, @@ -151,6 +169,9 @@ def __init__( use_fallback = use_fallback ) ) + # script() requires literal addressing and fails to eliminate dead branches + self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in)-1](transpose_in) + self.transpose_out = cuet.TransposeIrrepsLayout( e.output.irreps, source=e.output.layout, @@ -212,9 +233,7 @@ def forward( assert a.shape[-1] == dim # Transpose inputs - inputs[0] = self.transpose_in[0](inputs[0]) - if len(self.transpose_in) > 1: - inputs[1] = self.transpose_in[1](inputs[1]) + self.transpose_in.forward(inputs) # Compute tensor product output = self.tp(inputs, indices) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index fc7eb47..4553d01 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -85,6 +85,7 @@ def __init__( descriptors, device=device, math_dtype=math_dtype, + use_fallback = use_fallback, optimize_fallback=optimize_fallback, ) @@ -153,7 +154,7 @@ def __init__( self.x2_size = d.operands[-1].size self.has_cuda = False - if not use_fallback == True: + if use_fallback is None or not use_fallback: try: self.f = CUDAKernel(descriptors, device, math_dtype) self.has_cuda = True @@ -163,15 +164,15 @@ def __init__( except ImportError as e: logger.warning(f"Failed to initialize CUDA implementation: {e}") - if use_fallback == False: - raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available") - else: + if use_fallback is None or use_fallback: self.f = FallbackImpl( descriptors, device, math_dtype=math_dtype, optimize_fallback=optimize_fallback, ) + else: + raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available") def __repr__(self): has_cuda_kernel = ( diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 0a246cc..410c323 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -19,7 +19,6 @@ import torch import torch.fx -from torch.jit import Final from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 518d150..16f3cf2 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -187,14 +187,14 @@ def test_script( m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, use_fallback=False, - device="cuda", - optimize_fallback=False) + device="cuda") inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs ] + copy_inputs = [i.clone() for i in inputs] res = m(inputs) m_script = torch.jit.script(m) - # res_script = m_script(inputs) - # torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + res_script = m_script(copy_inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) From 779dd9ccd02f8e6585ca2485b908cf5717421ca6 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 5 Dec 2024 12:48:57 -0800 Subject: [PATCH 22/44] Fixed input mutation Signed-off-by: Boris Fomitchev --- .../primitives/equivariant_tensor_product.py | 22 ++++++++++++------- .../equivariant_tensor_product_test.py | 3 +-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index a8dbd5e..3a18f76 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -31,24 +31,30 @@ def forward( self, inputs: List[torch.Tensor] ): - inputs[0] = self.tp[0](inputs[0]) - + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + return ret + class Transpose2Dispatcher(Dispatcher): def forward( self, inputs: List[torch.Tensor] ): - inputs[0] = self.tp[0](inputs[0]) - inputs[1] = self.tp[1](inputs[1]) + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + return ret class Transpose3Dispatcher(Dispatcher): def forward( self, inputs: List[torch.Tensor] ): - inputs[0] = self.tp[0](inputs[0]) - inputs[1] = self.tp[1](inputs[1]) - inputs[2] = self.tp[2](inputs[2]) + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + ret[2] = self.tp[1](ret[2]) + return ret TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher] @@ -233,7 +239,7 @@ def forward( assert a.shape[-1] == dim # Transpose inputs - self.transpose_in.forward(inputs) + inputs = self.transpose_in(inputs) # Compute tensor product output = self.tp(inputs, indices) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 16f3cf2..60bbacf 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -192,9 +192,8 @@ def test_script( torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs ] - copy_inputs = [i.clone() for i in inputs] res = m(inputs) m_script = torch.jit.script(m) - res_script = m_script(copy_inputs) + res_script = m_script(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) From c315857c2128380d562b231af7b198d2ecacbad1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 5 Dec 2024 18:10:49 -0800 Subject: [PATCH 23/44] Fixed tests Signed-off-by: Boris Fomitchev --- .../primitives/equivariant_tensor_product.py | 17 +++++++- .../primitives/symmetric_tensor_product.py | 2 +- .../primitives/tensor_product.py | 10 ++--- .../tests/operations/linear_test.py | 37 ++++++++++------ .../operations/spherical_harmonics_test.py | 8 ++-- .../operations/symmetric_contraction_test.py | 2 +- .../tests/operations/tp_channel_wise_test.py | 42 ++++++++++++------- .../operations/tp_fully_connected_test.py | 14 ++++--- .../equivariant_tensor_product_test.py | 29 +++++++++---- 9 files changed, 106 insertions(+), 55 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 3a18f76..cacbd9c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -53,10 +53,22 @@ def forward( ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) - ret[2] = self.tp[1](ret[2]) + ret[2] = self.tp[2](ret[2]) return ret -TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher] +class Transpose4Dispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor] + ): + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + ret[2] = self.tp[2](ret[2]) + ret[3] = self.tp[3](ret[3]) + return ret + +TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher, Transpose4Dispatcher] class TPDispatcher(Dispatcher): def forward( @@ -175,6 +187,7 @@ def __init__( use_fallback = use_fallback ) ) + # script() requires literal addressing and fails to eliminate dead branches self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in)-1](transpose_in) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 4553d01..f02be9c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -325,7 +325,7 @@ def forward( i0 = i0.to(torch.int32) x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u) x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 410c323..0540eda 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -306,7 +306,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.descriptor = descriptor def forward(self, args:List[torch.Tensor]): - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): for oid, arg in enumerate(args): torch._assert( arg.shape[-1] == self.descriptor.operands[oid].size, @@ -476,7 +476,7 @@ def forward( x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) @@ -536,7 +536,7 @@ def forward( x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) @@ -589,7 +589,7 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: if x1.ndim == 1: x1 = x1.unsqueeze(0) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" ) @@ -622,7 +622,7 @@ def forward(self, inputs: List[torch.Tensor]): if x2.ndim == 1: x2 = x2.unsqueeze(0) - if not torch.jit.is_scripting(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index 1f786b7..d0d8a41 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -45,16 +45,26 @@ def test_linear_fwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, + use_fallback=False + ) + linear_fx = cuet.Linear( + irreps_in, + irreps_out, + layout=layout, + shared_weights=shared_weights, + device="cuda", + dtype=torch.float64, + use_fallback=True ) x = torch.randn(10, irreps_in.dim, dtype=torch.float64).cuda() if shared_weights: y = linear(x) - y_fx = linear(x, use_fallback=True) + y_fx = linear_fx(x) else: w = torch.randn(10, linear.weight_numel, dtype=torch.float64).cuda() y = linear(x, w) - y_fx = linear(x, w, use_fallback=True) + y_fx = linear_fx(x, w) assert y.shape == (10, irreps_out.dim) @@ -71,17 +81,18 @@ def test_linear_bwd_bwd( layout: cue.IrrepsLayout, shared_weights: bool, ): - linear = cuet.Linear( - irreps_in, - irreps_out, - layout=layout, - shared_weights=shared_weights, - device="cuda", - dtype=torch.float64, - ) - outputs = dict() for use_fallback in [True, False]: + linear = cuet.Linear( + irreps_in, + irreps_out, + layout=layout, + shared_weights=shared_weights, + device="cuda", + dtype=torch.float64, + use_fallback=use_fallback + ) + # reset the seed to ensure the same initialization torch.manual_seed(0) @@ -90,12 +101,12 @@ def test_linear_bwd_bwd( ) if shared_weights: - y = linear(x, use_fallback=use_fallback) + y = linear(x) else: w = torch.randn( 10, linear.weight_numel, requires_grad=True, dtype=torch.float64 ).cuda() - y = linear(x, w, use_fallback=use_fallback) + y = linear(x, w) (grad,) = torch.autograd.grad( y.pow(2).sum(), diff --git a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py index e1f07ab..6024401 100644 --- a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py @@ -26,15 +26,15 @@ ) @pytest.mark.parametrize("l", [1, 2, 3]) def test_spherical_harmonics(l: int, dtype, tol): - vec = torch.randn(3, dtype=dtype) + vec = torch.randn(3, dtype=dtype, device="cuda") axis = np.random.randn(3) angle = np.random.rand() scale = 1.3 yl = cuet.spherical_harmonics([l], vec, False) - R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype) - Rl = torch.from_numpy(cue.SO3(l).rotation(axis, angle)).to(dtype) + R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).cuda() + Rl = torch.from_numpy(cue.SO3(l).rotation(axis, angle)).to(dtype).cuda() yl1 = cuet.spherical_harmonics([l], scale * R @ vec, False) yl2 = scale**l * Rl @ yl @@ -43,7 +43,7 @@ def test_spherical_harmonics(l: int, dtype, tol): def test_spherical_harmonics_full(): - vec = torch.randn(3) + vec = torch.randn(3, device="cuda") ls = [0, 1, 2, 3] yl = cuet.spherical_harmonics(ls, vec, False) diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 62ba30e..5576397 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -30,7 +30,7 @@ @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("original_mace", [True, False]) -@pytest.mark.parametrize("batch", [0, 32]) +@pytest.mark.parametrize("batch", [1, 32]) def test_symmetric_contraction(dtype, layout, original_mace, batch): mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 9540c73..4e12f73 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("irreps3", list_of_irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) -@pytest.mark.parametrize("batch", [0, 32]) +@pytest.mark.parametrize("batch", [1, 32]) def test_channel_wise( irreps1: cue.Irreps, irreps2: cue.Irreps, @@ -50,19 +50,30 @@ def test_channel_wise( device="cuda", dtype=torch.float64, ) + m_fx = cuet.ChannelWiseTensorProduct( + irreps1, + irreps2, + irreps3, + shared_weights=True, + internal_weights=True, + layout=layout, + device="cuda", + dtype=torch.float64, + use_fallback=True + ) x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() - out1 = m(x1, x2, use_fallback=use_fallback) + out1 = m(x1, x2) d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: d = d.add_or_transpose_modes("u,ui,j,uk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() - out2 = mfx([m.weight, x1, x2], use_fallback=True) + mfx = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).cuda() + out2 = mfx([m.weight, x1, x2]) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) @@ -72,17 +83,6 @@ def test_channel_wise_bwd_bwd(): irreps2 = cue.Irreps("SO3", "0 + 1") irreps3 = cue.Irreps("SO3", "0 + 1") - m = cuet.ChannelWiseTensorProduct( - irreps1, - irreps2, - irreps3, - shared_weights=True, - internal_weights=False, - layout=cue.ir_mul, - device="cuda", - dtype=torch.float64, - ) - x1 = torch.randn( 32, irreps1.dim, device="cuda", requires_grad=True, dtype=torch.float64 ) @@ -95,6 +95,18 @@ def test_channel_wise_bwd_bwd(): outputs = {} for use_fallback in [True, False]: + m = cuet.ChannelWiseTensorProduct( + irreps1, + irreps2, + irreps3, + shared_weights=True, + internal_weights=False, + layout=cue.ir_mul, + device="cuda", + dtype=torch.float64, + use_fallback=use_fallback + ) + (grad1, grad2, grad3) = torch.autograd.grad( m(x1, x2, w).pow(2).sum(), (x1, x2, w), create_graph=True ) diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index 4e197fd..49c65a0 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -47,35 +47,37 @@ def test_fully_connected( layout=layout, device="cuda", dtype=torch.float64, + use_fallback=use_fallback ) x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() - out1 = m(x1, x2, use_fallback=use_fallback) + out1 = m(x1, x2) d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() + mfx = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).cuda() out2 = mfx( [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], - use_fallback=True, ).to(out1.dtype) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) def test_compile(): + device = "cuda" m = cuet.FullyConnectedTensorProduct( irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"), irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"), irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), layout=cue.mul_ir, - optimize_fallback=False, + device=device, + use_fallback=False ) m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, m.irreps_in1.dim) - input2 = torch.randn(100, m.irreps_in2.dim) + input1 = torch.randn(100, m.irreps_in1.dim, device=device) + input2 = torch.randn(100, m.irreps_in2.dim, device=device) m_compile(input1, input2) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 60bbacf..8f3ea2f 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -163,15 +163,28 @@ def test_precision_cuda_vs_fx( torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) -def test_compile(): - e = cue.descriptors.symmetric_contraction( - cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] - ) - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, device="cuda", optimize_fallback=False) +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +def test_compile( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, +): + device = torch.device("cuda:0") + + m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, + use_fallback=False, + device="cuda") + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, e.inputs[0].irreps.dim).cuda() - input2 = torch.randn(100, e.inputs[1].irreps.dim).cuda() - m_compile([input1, input2]) + res_script = m_compile(inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) @pytest.mark.parametrize("e", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) From ab590c8d8d0e875e8d88dfe267e746d225782b5f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 01:37:44 -0800 Subject: [PATCH 24/44] format with black --- .../primitives/equivariant_tensor_product.py | 65 +++++++++---------- .../primitives/symmetric_tensor_product.py | 34 +++++----- .../primitives/tensor_product.py | 53 ++++++++------- .../primitives/transpose.py | 26 ++++---- .../tests/operations/linear_test.py | 6 +- .../tests/operations/tp_channel_wise_test.py | 4 +- .../operations/tp_fully_connected_test.py | 4 +- .../equivariant_tensor_product_test.py | 22 +++---- .../symmetric_tensor_product_test.py | 10 ++- .../tests/primitives/tensor_product_test.py | 9 ++- 10 files changed, 121 insertions(+), 112 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cacbd9c..1808ada 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -26,41 +26,33 @@ def __init__(self, tp): super().__init__() self.tp = tp + class Transpose1Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) return ret - + + class Transpose2Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) return ret + class Transpose3Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) ret[2] = self.tp[2](ret[2]) return ret + class Transpose4Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) @@ -68,7 +60,14 @@ def forward( ret[3] = self.tp[3](ret[3]) return ret -TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher, Transpose4Dispatcher] + +TRANSPOSE_DISPATCHERS = [ + Transpose1Dispatcher, + Transpose2Dispatcher, + Transpose3Dispatcher, + Transpose4Dispatcher, +] + class TPDispatcher(Dispatcher): def forward( @@ -80,9 +79,9 @@ def forward( # TODO: at some point we will have kernel for this assert len(inputs) >= 1 inputs[0] = inputs[0][indices] - return self.tp(inputs) + return self.tp(inputs) + - class SymmetricTPDispatcher(Dispatcher): def forward( self, @@ -91,12 +90,13 @@ def forward( ) -> torch.Tensor: assert indices is None return self.tp(inputs[0]) - + + class IWeightedSymmetricTPDispatcher(Dispatcher): def forward( - self, - inputs: List[torch.Tensor], - indices: Optional[torch.Tensor] = None, + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1 = inputs if indices is None: @@ -104,11 +104,10 @@ def forward( x0.ndim == 2, f"Expected x0 to have shape (batch, dim), got {x0.shape}", ) - indices = torch.arange( - x1.shape[0], dtype=torch.int32, device=x1.device - ) + indices = torch.arange(x1.shape[0], dtype=torch.int32, device=x1.device) return self.tp(x0, indices, x1) + class EquivariantTensorProduct(torch.nn.Module): r"""Equivariant tensor product. @@ -146,7 +145,7 @@ class EquivariantTensorProduct(torch.nn.Module): ... [0., 0., 0., 0., 0., 0.]]) """ - + def __init__( self, e: cue.EquivariantTensorProduct, @@ -184,19 +183,19 @@ def __init__( source=layout_used, target=input_expected.layout, device=device, - use_fallback = use_fallback + use_fallback=use_fallback, ) ) # script() requires literal addressing and fails to eliminate dead branches - self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in)-1](transpose_in) - + self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in) - 1](transpose_in) + self.transpose_out = cuet.TransposeIrrepsLayout( e.output.irreps, source=e.output.layout, target=layout_out, device=device, - use_fallback = use_fallback + use_fallback=use_fallback, ) if any(d.num_operands != e.num_inputs + 1 for d in e.ds): @@ -228,7 +227,7 @@ def __init__( e.ds[0], device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index f02be9c..bb30e39 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -57,7 +57,7 @@ def __init__( d0, device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) else: @@ -85,13 +85,11 @@ def __init__( descriptors, device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) - def forward( - self, x0: torch.Tensor - ) -> torch.Tensor: + def forward(self, x0: torch.Tensor) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -109,7 +107,7 @@ def forward( out = self.f( torch.ones((1, 1), dtype=x0.dtype, device=x0.device), torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device), - x0 + x0, ) if self.f0 is not None: out += self.f0([]) @@ -153,7 +151,7 @@ def __init__( self.x1_size = d.operands[1].size self.x2_size = d.operands[-1].size self.has_cuda = False - + if use_fallback is None or not use_fallback: try: self.f = CUDAKernel(descriptors, device, math_dtype) @@ -163,7 +161,7 @@ def __init__( logger.info(f"Failed to initialize CUDA implementation: {e}") except ImportError as e: logger.warning(f"Failed to initialize CUDA implementation: {e}") - + if use_fallback is None or use_fallback: self.f = FallbackImpl( descriptors, @@ -172,11 +170,15 @@ def __init__( optimize_fallback=optimize_fallback, ) else: - raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available") + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available" + ) def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.has_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" + if self.has_cuda is not None + else "(without CUDA kernel)" ) return f"IWeightedSymmetricTensorProduct({has_cuda_kernel})" @@ -212,16 +214,13 @@ def forward( ) shape = broadcast_shapes([i0.shape, x1.shape[:-1]]) i0 = i0.expand(shape).reshape((prod(shape),)) - x1 = x1.expand(shape + (x1.shape[-1],)).reshape( - (prod(shape), x1.shape[-1]) - ) - - out = self.f(x0, i0, x1) + x1 = x1.expand(shape + (x1.shape[-1],)).reshape((prod(shape), x1.shape[-1])) + + out = self.f(x0, i0, x1) out = out.reshape(shape + (self.x2_size,)) return out - def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") @@ -359,6 +358,5 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) - for f in self.fs + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 0540eda..e0d5e98 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -23,12 +23,14 @@ logger = logging.getLogger(__name__) + def prod(numbers: List[int]): product = 1 for num in numbers: product *= num return product + def broadcast_shapes(shapes: List[List[int]]): if torch.jit.is_scripting(): max_len = 0 @@ -47,15 +49,23 @@ def broadcast_shapes(shapes: List[List[int]]): if isinstance(shape, (tuple, list)): for i in range(-1, -1 - len(shape), -1): if shape[i] < 0: - raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" - .format(shape[i], shape[i])) + raise RuntimeError( + "Trying to create tensor with negative dimension ({}): ({})".format( + shape[i], shape[i] + ) + ) if shape[i] == 1 or shape[i] == result[i]: continue if result[i] != 1: - raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape") + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) result[i] = shape[i] else: - raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape) + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) return torch.Size(result) else: return torch.functional.broadcast_shapes(*shapes) @@ -76,6 +86,7 @@ class TensorProduct(torch.nn.Module): RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. """ + def __init__( self, descriptor: stp.SegmentedTensorProduct, @@ -91,8 +102,8 @@ def __init__( math_dtype = torch.get_default_dtype() self.f = None self.has_cuda = False - - if not use_fallback == True: + + if not use_fallback == True: try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) self.has_cuda = True @@ -107,9 +118,11 @@ def __init__( "pip install cuequivariance-ops-torch-cu11\n" "pip install cuequivariance-ops-torch-cu12" ) - - if use_fallback == False: - raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available!") + + if use_fallback == False: + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available!" + ) else: self.f = _tensor_product_fx( descriptor, device, math_dtype, optimize_fallback is True @@ -118,7 +131,7 @@ def __init__( warnings.warn( "The fallback method is used but it has not been optimized. " "Consider setting optimize_fallback=True when creating the TensorProduct module." - ) + ) self._optimize_fallback = optimize_fallback def __repr__(self): @@ -216,9 +229,7 @@ def _tensor_product_fx( seg_shape = descriptor.get_segment_shape(-1, path) outputs += [ - out.reshape( - out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),) - ) + out.reshape(out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),)) ] if len(outputs) == 0: @@ -305,7 +316,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.module = module self.descriptor = descriptor - def forward(self, args:List[torch.Tensor]): + def forward(self, args: List[torch.Tensor]): if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): for oid, arg in enumerate(args): torch._assert( @@ -417,7 +428,7 @@ def _tensor_product_cuda( elif descriptor.num_operands == 4: return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) - + def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: # Make x have shape (Z, x.shape[-1]) or (x.shape[-1],) if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1: @@ -464,10 +475,7 @@ def __init__( def __repr__(self) -> str: return f"FusedTensorProductOp3({self.descriptor} (output last operand))" - def forward( - self, - inputs: List[torch.Tensor] - ) -> torch.Tensor: + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = self._perm(inputs[0], inputs[1]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -522,10 +530,7 @@ def __init__( def __repr__(self) -> str: return f"FusedTensorProductOp4({self.descriptor} (output last operand))" - def forward( - self, - inputs: List[torch.Tensor] - ) -> torch.Tensor: + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -545,6 +550,7 @@ def forward( return out.reshape(shape + (out.shape[-1],)) + class TensorProductUniform1d(torch.nn.Module): def __init__( self, @@ -571,6 +577,7 @@ def __init__( math_dtype=math_dtype, ).to(device=device) + class TensorProductUniform3x1d(TensorProductUniform1d): def __repr__(self): return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 252e45e..9e7156e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -36,19 +36,21 @@ def __init__( source: cue.IrrepsLayout, target: cue.IrrepsLayout, device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False + use_fallback: Optional[bool] = False, ): super().__init__() if (source, target) == (cue.mul_ir, cue.ir_mul): self.f = TransposeSegments( - [(mul, ir.dim) for mul, ir in irreps], device=device, - use_fallback = use_fallback + [(mul, ir.dim) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) elif (source, target) == (cue.ir_mul, cue.mul_ir): self.f = TransposeSegments( - [(ir.dim, mul) for mul, ir in irreps], device=device, - use_fallback = use_fallback + [(ir.dim, mul) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) else: self.f = torch.nn.Identity() @@ -61,9 +63,7 @@ def __init__( def __repr__(self): return f"TransposeIrrepsLayout({self.source} -> {self.target})" - def forward( - self, x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: r""" Perform the transposition. @@ -82,8 +82,10 @@ def forward( class TransposeSegments(torch.nn.Module): def __init__( - self, segments: list[tuple[int, int]], device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False + self, + segments: list[tuple[int, int]], + device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False, ): super().__init__() @@ -105,9 +107,7 @@ def __init__( def __repr__(self): return "TransposeSegments()" - def forward( - self, x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Perform the transposition of the input tensor using either a CUDA kernel or a PyTorch fallback. diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index d0d8a41..26b9e5b 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -45,7 +45,7 @@ def test_linear_fwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=False + use_fallback=False, ) linear_fx = cuet.Linear( irreps_in, @@ -54,7 +54,7 @@ def test_linear_fwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=True + use_fallback=True, ) x = torch.randn(10, irreps_in.dim, dtype=torch.float64).cuda() @@ -90,7 +90,7 @@ def test_linear_bwd_bwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) # reset the seed to ensure the same initialization diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 4e12f73..c48e1e1 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -59,7 +59,7 @@ def test_channel_wise( layout=layout, device="cuda", dtype=torch.float64, - use_fallback=True + use_fallback=True, ) x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() @@ -104,7 +104,7 @@ def test_channel_wise_bwd_bwd(): layout=cue.ir_mul, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) (grad1, grad2, grad3) = torch.autograd.grad( diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index 49c65a0..d9b19b4 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -47,7 +47,7 @@ def test_fully_connected( layout=layout, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() @@ -74,7 +74,7 @@ def test_compile(): irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), layout=cue.mul_ir, device=device, - use_fallback=False + use_fallback=False, ) m_compile = torch.compile(m, fullgraph=True) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 8f3ea2f..59d44b8 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -77,7 +77,7 @@ def test_performance_cuda_vs_fx( math_dtype=math_dtype, use_fallback=False, ) - + m1 = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, @@ -141,11 +141,7 @@ def test_precision_cuda_vs_fx( for inp in e.inputs ] m = cuet.EquivariantTensorProduct( - e, - layout=cue.ir_mul, - device=device, - math_dtype=math_dtype, - use_fallback=False + e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, use_fallback=False ) y0 = m(inputs) @@ -174,9 +170,9 @@ def test_compile( ): device = torch.device("cuda:0") - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, - use_fallback=False, - device="cuda") + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device="cuda" + ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs @@ -186,6 +182,7 @@ def test_compile( res_script = m_compile(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + @pytest.mark.parametrize("e", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) def test_script( @@ -198,9 +195,9 @@ def test_script( device = torch.device("cuda:0") - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, - use_fallback=False, - device="cuda") + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device="cuda" + ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs @@ -209,4 +206,3 @@ def test_script( m_script = torch.jit.script(m) res_script = m_script(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) - diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 909d71d..f5ab6aa 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -78,7 +78,11 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( out1 = m(x0, i0, x1) m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=torch.float64, device=device, use_fallback=True, optimize_fallback=True + ds, + math_dtype=torch.float64, + device=device, + use_fallback=True, + optimize_fallback=True, ) out2 = m(x0_, i0, x1_) @@ -122,7 +126,9 @@ def test_math_dtype( ds = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds - m = cuet.IWeightedSymmetricTensorProduct(ds, math_dtype=math_dtype, device=device, use_fallback=False) + m = cuet.IWeightedSymmetricTensorProduct( + ds, math_dtype=math_dtype, device=device, use_fallback=False + ) x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device) i0 = torch.randint(0, m.x0_size, (1000,), dtype=torch.int32, device=device) x1 = torch.randn((i0.size(0), m.x1_size), dtype=dtype, device=device) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index c7bf8e2..e4ba7cc 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -113,9 +113,13 @@ def test_primitive_tensor_product_cuda_vs_fx( ) m = torch.jit.script(m) out1 = m(inputs) - + m = cuet.TensorProduct( - d, device=device, math_dtype=torch.float64, use_fallback=True, optimize_fallback=False + d, + device=device, + math_dtype=torch.float64, + use_fallback=True, + optimize_fallback=False, ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] out2 = m(inputs_) @@ -136,4 +140,3 @@ def test_primitive_tensor_product_cuda_vs_fx( for g1, g2 in zip(double_grad1, double_grad2): torch.testing.assert_close(g1, g2.to(dtype), atol=100 * tol, rtol=100 * tol) - From ec1eb27425f7aea8f1289dffde94f180ae0925b2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 01:39:39 -0800 Subject: [PATCH 25/44] format with black --- .../primitives/equivariant_tensor_product.py | 65 +++++++++---------- .../primitives/symmetric_tensor_product.py | 34 +++++----- .../primitives/tensor_product.py | 53 ++++++++------- .../primitives/transpose.py | 26 ++++---- .../tests/operations/linear_test.py | 6 +- .../tests/operations/tp_channel_wise_test.py | 4 +- .../operations/tp_fully_connected_test.py | 4 +- .../equivariant_tensor_product_test.py | 22 +++---- .../symmetric_tensor_product_test.py | 10 ++- .../tests/primitives/tensor_product_test.py | 9 ++- 10 files changed, 121 insertions(+), 112 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cacbd9c..1808ada 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -26,41 +26,33 @@ def __init__(self, tp): super().__init__() self.tp = tp + class Transpose1Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) return ret - + + class Transpose2Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) return ret + class Transpose3Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) ret[2] = self.tp[2](ret[2]) return ret + class Transpose4Dispatcher(Dispatcher): - def forward( - self, - inputs: List[torch.Tensor] - ): + def forward(self, inputs: List[torch.Tensor]): ret = inputs.copy() ret[0] = self.tp[0](ret[0]) ret[1] = self.tp[1](ret[1]) @@ -68,7 +60,14 @@ def forward( ret[3] = self.tp[3](ret[3]) return ret -TRANSPOSE_DISPATCHERS = [Transpose1Dispatcher, Transpose2Dispatcher, Transpose3Dispatcher, Transpose4Dispatcher] + +TRANSPOSE_DISPATCHERS = [ + Transpose1Dispatcher, + Transpose2Dispatcher, + Transpose3Dispatcher, + Transpose4Dispatcher, +] + class TPDispatcher(Dispatcher): def forward( @@ -80,9 +79,9 @@ def forward( # TODO: at some point we will have kernel for this assert len(inputs) >= 1 inputs[0] = inputs[0][indices] - return self.tp(inputs) + return self.tp(inputs) + - class SymmetricTPDispatcher(Dispatcher): def forward( self, @@ -91,12 +90,13 @@ def forward( ) -> torch.Tensor: assert indices is None return self.tp(inputs[0]) - + + class IWeightedSymmetricTPDispatcher(Dispatcher): def forward( - self, - inputs: List[torch.Tensor], - indices: Optional[torch.Tensor] = None, + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: x0, x1 = inputs if indices is None: @@ -104,11 +104,10 @@ def forward( x0.ndim == 2, f"Expected x0 to have shape (batch, dim), got {x0.shape}", ) - indices = torch.arange( - x1.shape[0], dtype=torch.int32, device=x1.device - ) + indices = torch.arange(x1.shape[0], dtype=torch.int32, device=x1.device) return self.tp(x0, indices, x1) + class EquivariantTensorProduct(torch.nn.Module): r"""Equivariant tensor product. @@ -146,7 +145,7 @@ class EquivariantTensorProduct(torch.nn.Module): ... [0., 0., 0., 0., 0., 0.]]) """ - + def __init__( self, e: cue.EquivariantTensorProduct, @@ -184,19 +183,19 @@ def __init__( source=layout_used, target=input_expected.layout, device=device, - use_fallback = use_fallback + use_fallback=use_fallback, ) ) # script() requires literal addressing and fails to eliminate dead branches - self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in)-1](transpose_in) - + self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in) - 1](transpose_in) + self.transpose_out = cuet.TransposeIrrepsLayout( e.output.irreps, source=e.output.layout, target=layout_out, device=device, - use_fallback = use_fallback + use_fallback=use_fallback, ) if any(d.num_operands != e.num_inputs + 1 for d in e.ds): @@ -228,7 +227,7 @@ def __init__( e.ds[0], device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index f02be9c..bb30e39 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -57,7 +57,7 @@ def __init__( d0, device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) else: @@ -85,13 +85,11 @@ def __init__( descriptors, device=device, math_dtype=math_dtype, - use_fallback = use_fallback, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) - def forward( - self, x0: torch.Tensor - ) -> torch.Tensor: + def forward(self, x0: torch.Tensor) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -109,7 +107,7 @@ def forward( out = self.f( torch.ones((1, 1), dtype=x0.dtype, device=x0.device), torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device), - x0 + x0, ) if self.f0 is not None: out += self.f0([]) @@ -153,7 +151,7 @@ def __init__( self.x1_size = d.operands[1].size self.x2_size = d.operands[-1].size self.has_cuda = False - + if use_fallback is None or not use_fallback: try: self.f = CUDAKernel(descriptors, device, math_dtype) @@ -163,7 +161,7 @@ def __init__( logger.info(f"Failed to initialize CUDA implementation: {e}") except ImportError as e: logger.warning(f"Failed to initialize CUDA implementation: {e}") - + if use_fallback is None or use_fallback: self.f = FallbackImpl( descriptors, @@ -172,11 +170,15 @@ def __init__( optimize_fallback=optimize_fallback, ) else: - raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available") + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available" + ) def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.has_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" + if self.has_cuda is not None + else "(without CUDA kernel)" ) return f"IWeightedSymmetricTensorProduct({has_cuda_kernel})" @@ -212,16 +214,13 @@ def forward( ) shape = broadcast_shapes([i0.shape, x1.shape[:-1]]) i0 = i0.expand(shape).reshape((prod(shape),)) - x1 = x1.expand(shape + (x1.shape[-1],)).reshape( - (prod(shape), x1.shape[-1]) - ) - - out = self.f(x0, i0, x1) + x1 = x1.expand(shape + (x1.shape[-1],)).reshape((prod(shape), x1.shape[-1])) + + out = self.f(x0, i0, x1) out = out.reshape(shape + (self.x2_size,)) return out - def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") @@ -359,6 +358,5 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) - for f in self.fs + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 0540eda..e0d5e98 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -23,12 +23,14 @@ logger = logging.getLogger(__name__) + def prod(numbers: List[int]): product = 1 for num in numbers: product *= num return product + def broadcast_shapes(shapes: List[List[int]]): if torch.jit.is_scripting(): max_len = 0 @@ -47,15 +49,23 @@ def broadcast_shapes(shapes: List[List[int]]): if isinstance(shape, (tuple, list)): for i in range(-1, -1 - len(shape), -1): if shape[i] < 0: - raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" - .format(shape[i], shape[i])) + raise RuntimeError( + "Trying to create tensor with negative dimension ({}): ({})".format( + shape[i], shape[i] + ) + ) if shape[i] == 1 or shape[i] == result[i]: continue if result[i] != 1: - raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape") + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) result[i] = shape[i] else: - raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape) + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) return torch.Size(result) else: return torch.functional.broadcast_shapes(*shapes) @@ -76,6 +86,7 @@ class TensorProduct(torch.nn.Module): RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. """ + def __init__( self, descriptor: stp.SegmentedTensorProduct, @@ -91,8 +102,8 @@ def __init__( math_dtype = torch.get_default_dtype() self.f = None self.has_cuda = False - - if not use_fallback == True: + + if not use_fallback == True: try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) self.has_cuda = True @@ -107,9 +118,11 @@ def __init__( "pip install cuequivariance-ops-torch-cu11\n" "pip install cuequivariance-ops-torch-cu12" ) - - if use_fallback == False: - raise RuntimeError("`use_fallback` is `False` and no CUDA kernel is available!") + + if use_fallback == False: + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available!" + ) else: self.f = _tensor_product_fx( descriptor, device, math_dtype, optimize_fallback is True @@ -118,7 +131,7 @@ def __init__( warnings.warn( "The fallback method is used but it has not been optimized. " "Consider setting optimize_fallback=True when creating the TensorProduct module." - ) + ) self._optimize_fallback = optimize_fallback def __repr__(self): @@ -216,9 +229,7 @@ def _tensor_product_fx( seg_shape = descriptor.get_segment_shape(-1, path) outputs += [ - out.reshape( - out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),) - ) + out.reshape(out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),)) ] if len(outputs) == 0: @@ -305,7 +316,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu self.module = module self.descriptor = descriptor - def forward(self, args:List[torch.Tensor]): + def forward(self, args: List[torch.Tensor]): if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): for oid, arg in enumerate(args): torch._assert( @@ -417,7 +428,7 @@ def _tensor_product_cuda( elif descriptor.num_operands == 4: return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) - + def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: # Make x have shape (Z, x.shape[-1]) or (x.shape[-1],) if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1: @@ -464,10 +475,7 @@ def __init__( def __repr__(self) -> str: return f"FusedTensorProductOp3({self.descriptor} (output last operand))" - def forward( - self, - inputs: List[torch.Tensor] - ) -> torch.Tensor: + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1 = self._perm(inputs[0], inputs[1]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -522,10 +530,7 @@ def __init__( def __repr__(self) -> str: return f"FusedTensorProductOp4({self.descriptor} (output last operand))" - def forward( - self, - inputs: List[torch.Tensor] - ) -> torch.Tensor: + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim @@ -545,6 +550,7 @@ def forward( return out.reshape(shape + (out.shape[-1],)) + class TensorProductUniform1d(torch.nn.Module): def __init__( self, @@ -571,6 +577,7 @@ def __init__( math_dtype=math_dtype, ).to(device=device) + class TensorProductUniform3x1d(TensorProductUniform1d): def __repr__(self): return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 252e45e..9e7156e 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -36,19 +36,21 @@ def __init__( source: cue.IrrepsLayout, target: cue.IrrepsLayout, device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False + use_fallback: Optional[bool] = False, ): super().__init__() if (source, target) == (cue.mul_ir, cue.ir_mul): self.f = TransposeSegments( - [(mul, ir.dim) for mul, ir in irreps], device=device, - use_fallback = use_fallback + [(mul, ir.dim) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) elif (source, target) == (cue.ir_mul, cue.mul_ir): self.f = TransposeSegments( - [(ir.dim, mul) for mul, ir in irreps], device=device, - use_fallback = use_fallback + [(ir.dim, mul) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) else: self.f = torch.nn.Identity() @@ -61,9 +63,7 @@ def __init__( def __repr__(self): return f"TransposeIrrepsLayout({self.source} -> {self.target})" - def forward( - self, x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: r""" Perform the transposition. @@ -82,8 +82,10 @@ def forward( class TransposeSegments(torch.nn.Module): def __init__( - self, segments: list[tuple[int, int]], device: Optional[torch.device] = None, - use_fallback: Optional[bool] = False + self, + segments: list[tuple[int, int]], + device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False, ): super().__init__() @@ -105,9 +107,7 @@ def __init__( def __repr__(self): return "TransposeSegments()" - def forward( - self, x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Perform the transposition of the input tensor using either a CUDA kernel or a PyTorch fallback. diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index d0d8a41..26b9e5b 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -45,7 +45,7 @@ def test_linear_fwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=False + use_fallback=False, ) linear_fx = cuet.Linear( irreps_in, @@ -54,7 +54,7 @@ def test_linear_fwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=True + use_fallback=True, ) x = torch.randn(10, irreps_in.dim, dtype=torch.float64).cuda() @@ -90,7 +90,7 @@ def test_linear_bwd_bwd( shared_weights=shared_weights, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) # reset the seed to ensure the same initialization diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 4e12f73..c48e1e1 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -59,7 +59,7 @@ def test_channel_wise( layout=layout, device="cuda", dtype=torch.float64, - use_fallback=True + use_fallback=True, ) x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() @@ -104,7 +104,7 @@ def test_channel_wise_bwd_bwd(): layout=cue.ir_mul, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) (grad1, grad2, grad3) = torch.autograd.grad( diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index 49c65a0..d9b19b4 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -47,7 +47,7 @@ def test_fully_connected( layout=layout, device="cuda", dtype=torch.float64, - use_fallback=use_fallback + use_fallback=use_fallback, ) x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() @@ -74,7 +74,7 @@ def test_compile(): irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), layout=cue.mul_ir, device=device, - use_fallback=False + use_fallback=False, ) m_compile = torch.compile(m, fullgraph=True) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 8f3ea2f..59d44b8 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -77,7 +77,7 @@ def test_performance_cuda_vs_fx( math_dtype=math_dtype, use_fallback=False, ) - + m1 = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, @@ -141,11 +141,7 @@ def test_precision_cuda_vs_fx( for inp in e.inputs ] m = cuet.EquivariantTensorProduct( - e, - layout=cue.ir_mul, - device=device, - math_dtype=math_dtype, - use_fallback=False + e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, use_fallback=False ) y0 = m(inputs) @@ -174,9 +170,9 @@ def test_compile( ): device = torch.device("cuda:0") - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, - use_fallback=False, - device="cuda") + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device="cuda" + ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs @@ -186,6 +182,7 @@ def test_compile( res_script = m_compile(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + @pytest.mark.parametrize("e", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) def test_script( @@ -198,9 +195,9 @@ def test_script( device = torch.device("cuda:0") - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, - use_fallback=False, - device="cuda") + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device="cuda" + ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs @@ -209,4 +206,3 @@ def test_script( m_script = torch.jit.script(m) res_script = m_script(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) - diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 909d71d..f5ab6aa 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -78,7 +78,11 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( out1 = m(x0, i0, x1) m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=torch.float64, device=device, use_fallback=True, optimize_fallback=True + ds, + math_dtype=torch.float64, + device=device, + use_fallback=True, + optimize_fallback=True, ) out2 = m(x0_, i0, x1_) @@ -122,7 +126,9 @@ def test_math_dtype( ds = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds - m = cuet.IWeightedSymmetricTensorProduct(ds, math_dtype=math_dtype, device=device, use_fallback=False) + m = cuet.IWeightedSymmetricTensorProduct( + ds, math_dtype=math_dtype, device=device, use_fallback=False + ) x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device) i0 = torch.randint(0, m.x0_size, (1000,), dtype=torch.int32, device=device) x1 = torch.randn((i0.size(0), m.x1_size), dtype=dtype, device=device) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index c7bf8e2..e4ba7cc 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -113,9 +113,13 @@ def test_primitive_tensor_product_cuda_vs_fx( ) m = torch.jit.script(m) out1 = m(inputs) - + m = cuet.TensorProduct( - d, device=device, math_dtype=torch.float64, use_fallback=True, optimize_fallback=False + d, + device=device, + math_dtype=torch.float64, + use_fallback=True, + optimize_fallback=False, ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] out2 = m(inputs_) @@ -136,4 +140,3 @@ def test_primitive_tensor_product_cuda_vs_fx( for g1, g2 in zip(double_grad1, double_grad2): torch.testing.assert_close(g1, g2.to(dtype), atol=100 * tol, rtol=100 * tol) - From faf235eb7e68484b4e37a56a5d09889ddb237119 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 01:58:46 -0800 Subject: [PATCH 26/44] fix tests --- cuequivariance_torch/tests/operations/linear_test.py | 4 ++++ .../tests/operations/tp_channel_wise_test.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index 26b9e5b..f06c8e5 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -38,6 +38,7 @@ def test_linear_fwd( layout: cue.IrrepsLayout, shared_weights: bool, ): + torch.manual_seed(0) linear = cuet.Linear( irreps_in, irreps_out, @@ -47,6 +48,8 @@ def test_linear_fwd( dtype=torch.float64, use_fallback=False, ) + + torch.manual_seed(0) linear_fx = cuet.Linear( irreps_in, irreps_out, @@ -83,6 +86,7 @@ def test_linear_bwd_bwd( ): outputs = dict() for use_fallback in [True, False]: + torch.manual_seed(0) linear = cuet.Linear( irreps_in, irreps_out, diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index c48e1e1..d3628ab 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -89,9 +89,6 @@ def test_channel_wise_bwd_bwd(): x2 = torch.randn( 32, irreps2.dim, device="cuda", requires_grad=True, dtype=torch.float64 ) - w = torch.randn( - m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 - ) outputs = {} for use_fallback in [True, False]: @@ -107,6 +104,11 @@ def test_channel_wise_bwd_bwd(): use_fallback=use_fallback, ) + torch.manual_seed(0) + w = torch.randn( + m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 + ) + (grad1, grad2, grad3) = torch.autograd.grad( m(x1, x2, w).pow(2).sum(), (x1, x2, w), create_graph=True ) From c476af9599a96a21a5650333838d810d30d5c2ec Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:11:36 -0800 Subject: [PATCH 27/44] fix missing parenthesis --- .../cuequivariance_torch/primitives/tensor_product.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index e0d5e98..ac8cf42 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -541,7 +541,7 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) - if not torch.jit.is_scripting and not torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) @@ -607,7 +607,6 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: class TensorProductUniform4x1d(TensorProductUniform1d): - def __repr__(self): return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" From 994b8d9640ee752938155f930dfa8161d0718dcc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:11:53 -0800 Subject: [PATCH 28/44] fix tests: increase torch._dynamo.config.cache_size_limit --- .../tests/primitives/equivariant_tensor_product_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 59d44b8..5001a0c 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -16,11 +16,14 @@ import pytest import torch +import torch._dynamo import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors +torch._dynamo.config.cache_size_limit = 100 + def make_descriptors(): # This ETP will trigger the fusedTP kernel @@ -171,7 +174,7 @@ def test_compile( device = torch.device("cuda:0") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda" + e, layout=cue.mul_ir, use_fallback=False, device="cuda", math_dtype=math_dtype ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -192,11 +195,10 @@ def test_script( atol: float, rtol: float, ): - device = torch.device("cuda:0") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda" + e, layout=cue.mul_ir, use_fallback=False, device="cuda", math_dtype=math_dtype ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) From f240eb8ca5a621adc2f1a1385f89427922ef8717 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:24:25 -0800 Subject: [PATCH 29/44] fix docstring tests --- .../layers/tp_conv_fully_connected.py | 22 +------------------ .../operations/symmetric_contraction.py | 5 +++-- .../primitives/equivariant_tensor_product.py | 18 ++++++--------- 3 files changed, 11 insertions(+), 34 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index e4b3a58..c3eca17 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -72,27 +72,7 @@ class FullyConnectedTensorProductConv(nn.Module): >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul).cuda() >>> conv1 - FullyConnectedTensorProductConv( - (tp): FullyConnectedTensorProduct( - shared_weights=False, internal_weights=False, weight_numel=64 - (f): EquivariantTensorProduct( - EquivariantTensorProduct(64x0e x 4x0e+4x1o x 0e+1o -> 4x0e+4x1o) - (transpose_in): ModuleList( - (0-2): 3 x TransposeIrrepsLayout((irrep,mul) -> (irrep,mul)) - ) - (transpose_out): TransposeIrrepsLayout((irrep,mul) -> (irrep,mul)) - (tp): TensorProduct(uvw,iu,jv,kw+ijk sizes=64,16,4,16 num_segments=4,2,2,2 num_paths=4 i={1, 3} j={1, 3} k={1, 3} u=4 v=1 w=4 (with CUDA kernel)) - ) - ) - (batch_norm): BatchNorm(4x0e+4x1o, layout=(irrep,mul), eps=1e-05, momentum=0.1) - (mlp): Sequential( - (0): Linear(in_features=6, out_features=16, bias=True) - (1): ReLU() - (2): Linear(in_features=16, out_features=16, bias=True) - (3): ReLU() - (4): Linear(in_features=16, out_features=64, bias=True) - ) - ) + FullyConnectedTensorProductConv(...) >>> # out = conv1(src_features, edge_sh, edge_emb, graph) **Case 2**: If edge_emb is constructed by concatenating scalar features from diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index fac5739..e1b8d66 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -71,14 +71,15 @@ class SymmetricContraction(torch.nn.Module): ... layout_out=cue.mul_ir, ... original_mace=True, ... dtype=torch.float64, + ... device=torch.device("cuda"), ... ) Then the execution is as follows: - >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64) + >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64).cuda() >>> # with node_attrs_index being the index version of node_attrs, sth like: >>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int() - >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32) + >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32).cuda() >>> # OLD CALL: >>> # symmetric_contractions_old(node_feats, node_attrs) >>> # NEW CALL: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 1808ada..48f4cdd 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -127,23 +127,19 @@ class EquivariantTensorProduct(torch.nn.Module): >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... ) - >>> w = torch.ones(e.inputs[0].irreps.dim) - >>> x1 = torch.ones(17, e.inputs[1].irreps.dim) - >>> x2 = torch.ones(17, e.inputs[2].irreps.dim) - >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul) + >>> w = torch.ones(e.inputs[0].irreps.dim).cuda() + >>> x1 = torch.ones(17, e.inputs[1].irreps.dim).cuda() + >>> x2 = torch.ones(17, e.inputs[2].irreps.dim).cuda() + >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=torch.device("cuda")) >>> tp([w, x1, x2]) - tensor([[0., 0., 0., 0., 0., 0.], - ... - [0., 0., 0., 0., 0., 0.]]) + tensor([[0., 0., 0., 0., 0., 0.],...) You can optionally index the first input tensor: - >>> w = torch.ones(3, e.inputs[0].irreps.dim) + >>> w = torch.ones(3, e.inputs[0].irreps.dim).cuda() >>> indices = torch.randint(3, (17,)) >>> tp([w, x1, x2], indices=indices) - tensor([[0., 0., 0., 0., 0., 0.], - ... - [0., 0., 0., 0., 0., 0.]]) + tensor([[0., 0., 0., 0., 0., 0.],...) """ def __init__( From fbfb9d084187e82f95bd37fa216fdafcbcb3f76c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:29:01 -0800 Subject: [PATCH 30/44] replace == by is --- .../cuequivariance_torch/primitives/tensor_product.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index ac8cf42..f5f783d 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -103,7 +103,7 @@ def __init__( self.f = None self.has_cuda = False - if not use_fallback == True: + if use_fallback is None or use_fallback is False: try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) self.has_cuda = True @@ -119,7 +119,7 @@ def __init__( "pip install cuequivariance-ops-torch-cu12" ) - if use_fallback == False: + if use_fallback is False: raise RuntimeError( "`use_fallback` is `False` and no CUDA kernel is available!" ) From dc20be5f6285d0e6e9524748fcd5936ef543ecff Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:46:20 -0800 Subject: [PATCH 31/44] clean use_fallback conditions --- .../primitives/symmetric_tensor_product.py | 13 +++++----- .../primitives/tensor_product.py | 14 ++++++----- .../primitives/transpose.py | 25 ++++++++++--------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index bb30e39..4b1485d 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -152,7 +152,7 @@ def __init__( self.x2_size = d.operands[-1].size self.has_cuda = False - if use_fallback is None or not use_fallback: + if use_fallback is None or use_fallback is False: try: self.f = CUDAKernel(descriptors, device, math_dtype) self.has_cuda = True @@ -162,17 +162,18 @@ def __init__( except ImportError as e: logger.warning(f"Failed to initialize CUDA implementation: {e}") - if use_fallback is None or use_fallback: + if use_fallback is False and not self.has_cuda: + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available!" + ) + + if self.f is None: self.f = FallbackImpl( descriptors, device, math_dtype=math_dtype, optimize_fallback=optimize_fallback, ) - else: - raise RuntimeError( - "`use_fallback` is `False` and no CUDA kernel is available" - ) def __repr__(self): has_cuda_kernel = ( diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index f5f783d..23392eb 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -107,7 +107,6 @@ def __init__( try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) self.has_cuda = True - return except NotImplementedError as e: logger.info(f"CUDA implementation not available: {e}") except ImportError as e: @@ -119,19 +118,22 @@ def __init__( "pip install cuequivariance-ops-torch-cu12" ) - if use_fallback is False: + if use_fallback is False and not self.has_cuda: raise RuntimeError( "`use_fallback` is `False` and no CUDA kernel is available!" ) - else: - self.f = _tensor_product_fx( - descriptor, device, math_dtype, optimize_fallback is True - ) + + if self.f is None: if optimize_fallback is None: + optimize_fallback = False warnings.warn( "The fallback method is used but it has not been optimized. " "Consider setting optimize_fallback=True when creating the TensorProduct module." ) + + self.f = _tensor_product_fx( + descriptor, device, math_dtype, optimize_fallback + ) self._optimize_fallback = optimize_fallback def __repr__(self): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 9e7156e..848920c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -92,16 +92,20 @@ def __init__( info = _transpose_info(segments, device=device) if info is not None: - try: - import cuequivariance_ops_torch # noqa: F401 - except ImportError: - self.f_cuda = None - else: - self.f_cuda = _transpose(info).to(device=device) - if use_fallback: + if use_fallback is False or use_fallback is None: + try: + import cuequivariance_ops_torch # noqa: F401 + except ImportError: + self.f = None + else: + self.f = _transpose(info).to(device=device) + + if use_fallback is False and self.f is None: + raise RuntimeError("CUDA kernel not available for TransposeSegments.") + + if self.f is None: self.f = _transpose_segments_fx(segments).to(device=device) else: - self.f_cuda = torch.nn.Identity() self.f = torch.nn.Identity() def __repr__(self): @@ -130,10 +134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: RuntimeError If `use_fallback` is `False` and a CUDA kernel is not available or the input is not on CUDA. """ - if self.f_cuda is not None: - return self.f_cuda(x) - else: - return self.f(x) + return self.f(x) def _transpose_segments_fx(segments: list[tuple[int, int]]) -> torch.nn.Module: From 4b201c35d6c90f383b91e692bef2316981b538e9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:47:55 -0800 Subject: [PATCH 32/44] fix --- .../cuequivariance_torch/primitives/transpose.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 848920c..b23777a 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -90,13 +90,14 @@ def __init__( super().__init__() info = _transpose_info(segments, device=device) + self.f = None if info is not None: if use_fallback is False or use_fallback is None: try: import cuequivariance_ops_torch # noqa: F401 except ImportError: - self.f = None + pass else: self.f = _transpose(info).to(device=device) From b5b59b8daf3b3edcbe01127fc3764bafe7f2edac Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 6 Dec 2024 02:49:06 -0800 Subject: [PATCH 33/44] fix --- .../cuequivariance_torch/primitives/symmetric_tensor_product.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 4b1485d..b62d1b8 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -150,7 +150,9 @@ def __init__( self.x0_size = d.operands[0].size self.x1_size = d.operands[1].size self.x2_size = d.operands[-1].size + self.has_cuda = False + self.f = None if use_fallback is None or use_fallback is False: try: From 72baf17fe6ef76a6c4e5fa7ad84234f6b334e5d1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 6 Dec 2024 19:09:44 -0800 Subject: [PATCH 34/44] Export test added, scripting fallback attempt Signed-off-by: Boris Fomitchev --- .../primitives/symmetric_tensor_product.py | 3 +- .../primitives/tensor_product.py | 39 ++- .../equivariant_tensor_product_test.py | 38 ++- .../tests/primitives/utils.py | 267 ++++++++++++++++++ 4 files changed, 338 insertions(+), 9 deletions(-) create mode 100644 cuequivariance_torch/tests/primitives/utils.py diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index bb30e39..2b9b564 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -348,6 +348,7 @@ def __init__( d, device=device, math_dtype=math_dtype, + use_fallback=True, optimize_fallback=optimize_fallback, ) for d in stps @@ -358,5 +359,5 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs + f([x0[i0]] + [x1] * (f.num_operands - 2)) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index e0d5e98..064e286 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -102,7 +102,8 @@ def __init__( math_dtype = torch.get_default_dtype() self.f = None self.has_cuda = False - + self.num_operands = descriptor.num_operands + if not use_fallback == True: try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) @@ -289,7 +290,7 @@ def __init__(self, descriptor: stp.SegmentedTensorProduct): ), ) - def forward(self, *args): + def forward(self, args: List[torch.Tensor]): shape = broadcast_shapes([arg.shape[:-1] for arg in args]) output = torch.zeros( shape + (descriptor.operands[-1].size,), @@ -310,10 +311,37 @@ def forward(self, *args): return _Wrapper(graphmod, descriptor) +class _Caller(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.module = module + +class _NoArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module() + +class _OneArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0]) + +class _TwoArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1]) + +class _ThreeArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2]) + +class _FourArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3]) + +CALL_DISPATCHERS = [_NoArgCaller, _OneArgCaller, _TwoArgCaller, _ThreeArgCaller, _FourArgCaller] + class _Wrapper(torch.nn.Module): def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProduct): super().__init__() - self.module = module + self.module = CALL_DISPATCHERS[descriptor.num_operands-1](module) self.descriptor = descriptor def forward(self, args: List[torch.Tensor]): @@ -336,8 +364,7 @@ def forward(self, args: List[torch.Tensor]): ) for arg in args ] - - out = self.module(*args) + out = self.module(args) return out.reshape(shape + (out.shape[-1],)) @@ -541,7 +568,7 @@ def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) - if not torch.jit.is_scripting and not torch.compiler.is_compiling(): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): logger.debug( f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" ) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 59d44b8..a6099a1 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -21,6 +21,9 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +from utils import ( + module_with_mode, +) def make_descriptors(): # This ETP will trigger the fusedTP kernel @@ -171,7 +174,7 @@ def test_compile( device = torch.device("cuda:0") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda" + e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device="cuda" ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -196,7 +199,7 @@ def test_script( device = torch.device("cuda:0") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda" + e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device="cuda" ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -206,3 +209,34 @@ def test_script( m_script = torch.jit.script(m) res_script = m_script(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + +# export_modes = ["onnx", "onnx_dynamo", "trt", "torch_trt", "jit"] +export_modes = ["trt","onnx"] + +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +@pytest.mark.parametrize("mode", export_modes) + +def test_export( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, + mode: str, + tmp_path +): + + device = torch.device("cuda:0") + + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device="cuda" + ) + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) + m_script = module_with_mode(mode, m, [inputs], math_dtype, tmp_path) + res_script = m_script(inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/utils.py b/cuequivariance_torch/tests/primitives/utils.py new file mode 100644 index 0000000..ce646fd --- /dev/null +++ b/cuequivariance_torch/tests/primitives/utils.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + +import os +import pytest +import torch +from typing import Sequence + +try: + import onnx # noqa: F401 + import onnxscript # noqa: F401 + import onnxruntime # noqa: F401 + import cuequivariance_ops_torch.onnx # noqa: F401 + from cuequivariance_ops_torch.tensorrt import register_plugins + + ONNX_AVAILABLE = True +except Exception: + ONNX_AVAILABLE = False + + +try: + import torch_tensorrt + + TORCH_TRT_AVAILABLE = True +except Exception: + TORCH_TRT_AVAILABLE = False + + +def verify_onnx(module, onnx_module, inputs, dtype): + if dtype != torch.float32: + pytest.skip("onnxrt only checked for float32") + from onnxruntime import SessionOptions + from onnxruntime_extensions import get_library_path + from torch.onnx.verification import ( + _compare_onnx_pytorch_model, + VerificationOptions, + ) + + original_init = SessionOptions.__init__ + + def new_init(self): + original_init(self) + try: + self.register_custom_ops_library(get_library_path()) + except Exception: + pass + + SessionOptions.__init__ = new_init + _compare_onnx_pytorch_model( + module, onnx_module, tuple(inputs), None, None, VerificationOptions() + ) + SessionOptions.__init__ = original_init + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def verify_trt(module, onnx_module, inputs, dtype): + import tensorrt + from pkg_resources import parse_version + + if parse_version(tensorrt.__version__) < parse_version("10.3.0"): + pytest.skip("TRT < 10.3.0 is not supported!") + if dtype == torch.float64: + pytest.skip("TRT does not support float64") + + from polygraphy.backend.trt import ( + engine_from_network, + network_from_onnx_path, + TrtRunner, + CreateConfig, + ) + from polygraphy.backend.onnxrt import OnnxrtRunner + from polygraphy.comparator import Comparator, DataLoader + from onnxruntime import InferenceSession, SessionOptions + from onnxruntime_extensions import get_library_path + + register_plugins() + + network = network_from_onnx_path(onnx_module) + trt_engine = engine_from_network(network, config=CreateConfig()) + + if dtype != torch.float32: + pytest.skip("Comparator only supports float32") + + # Create runners for ONNX and TRT models + trt_runner = TrtRunner(trt_engine) + + options = SessionOptions() + options.register_custom_ops_library(get_library_path()) + onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) + + results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) + Comparator.compare_accuracy(results) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def module_with_mode( + mode, + module, + inputs, + math_dtype, + tmp_path, + grad_modes=["eager", "compile", "jit", "export"], +): + if isinstance(inputs[0], list): + dtype = inputs[0][0].dtype + else: + dtype = inputs[0].dtype + if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo", "export"]: + if not ONNX_AVAILABLE: + pytest.skip("ONNX not available!") + if dtype == torch.float64 or math_dtype == torch.float64: + pytest.skip("TRT/ORT do not support float64") + + with torch.set_grad_enabled(mode in grad_modes): + if mode == "compile": + import sys + + if sys.version_info.major == 3 and sys.version_info.minor >= 12: + pytest.skip("torch dynamo needs cpy <= 3.11") + module = torch.compile(module) + elif mode == "fx": + module = torch.fx.symbolic_trace(module) + elif mode == "jit": + module = torch.jit.trace(module, inputs) + fname = os.path.join(tmp_path, "test.ts") + torch.jit.save(module, fname) + module = torch.jit.load(fname) + elif mode == "export": + exp_program = torch.export.export(module, tuple(inputs)) + fname = os.path.join(tmp_path, "test.pt2") + torch.export.save(exp_program, fname) + del exp_program + module = torch.export.load(fname).module() + elif mode == "torch_trt": + if not TORCH_TRT_AVAILABLE: + pytest.skip("torch_tensorrt is not installed!") + register_plugins() + exp_program = torch_tensorrt.dynamo.trace(module, inputs) + module = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + require_full_compilation=True, + min_block_size=1, + enabled_precisions={torch.float32, dtype}, + # dryrun=True + ) + elif mode == "onnx" or mode == "trt": + try: + onnx_path = os.path.join(tmp_path, "test.onnx") + torch.onnx.export( + module, tuple(inputs), onnx_path, opset_version=17, verbose=False + ) + if mode == "trt": + verify_trt(module, onnx_path, inputs, dtype) + else: + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX/TRT is not available") + + elif mode == "onnx_dynamo": + try: + from cuequivariance_ops_torch.onnx import ( + cuequivariance_ops_torch_onnx_registry, + ) + + export_options = torch.onnx.ExportOptions( + onnx_registry=cuequivariance_ops_torch_onnx_registry + ) + onnx_program = torch.onnx.dynamo_export( + module, *inputs, export_options=export_options + ) + onnx_path = os.path.join(tmp_path, "test.onnx") + onnx_program.save(onnx_path) + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX is not available") + elif mode == "eager": + pass + else: + raise ValueError(f"No such mode: {mode}") + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + return module + + +def create_random_tensor_2d(batch_size, stride, requires_grad, dtype, is_shared): + data = torch.randn( + (stride,) if is_shared else (batch_size, stride), + dtype=dtype, + device="cuda", + ).requires_grad_(requires_grad) + + return data + + +def maybe_detach_and_to(tensor, *args, **kwargs): + if tensor is not None: + return tensor.clone().detach().to(*args, **kwargs) + return None + + +def run_fwd_test(module, x: Sequence): + with torch.no_grad(): + out = module(*x) + test_output = [maybe_detach_and_to(out, dtype=torch.float32)] + return test_output + + +def run_fwd_bwd_test(module, x: Sequence): + out = module(*x) + + loss = out.sum() + loss.backward() + + test_output = [maybe_detach_and_to(out, dtype=torch.float32)] + test_output.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) + + return test_output + + +def run_bwd_bwd_test(module, x: Sequence): + test_outputs = [] + out = module(*x) + grads = torch.autograd.grad(out.pow(2).sum(), x, create_graph=True) + test_outputs.extend([maybe_detach_and_to(g, dtype=torch.float32) for g in grads]) + loss = sum([g.sum() for g in grads]) + loss.backward() + test_outputs.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) + return test_outputs + + +def assert_close_modules(m_test, m_ref, inputs_test, procedure, tol_dict): + outs_test = procedure(m_test, inputs_test) + + inputs_ref = [ + x.clone() + .detach() + .to(device="cuda", dtype=torch.float32) + .requires_grad_(x.requires_grad) + for x in inputs_test + ] + outs_ref = procedure(m_ref, inputs_ref) + for out_test, out_ref in zip(outs_test, outs_ref): + torch.testing.assert_close(out_test, out_ref, **tol_dict) + + +tol_dict = { + # we compare against double for precision reasons + # hence FP64 and FP32 threshold are the same + (torch.float64, torch.float64): {"atol": 1e-9, "rtol": 1e-5}, + (torch.float32, torch.float64): {"atol": 1e-4, "rtol": 1e-5}, + (torch.float64, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, + (torch.float32, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, + (torch.bfloat16, torch.float32): {"atol": 4.0, "rtol": 1e-2}, + (torch.float16, torch.float32): {"atol": 0.25, "rtol": 1e-2}, +} From 8d319290985a1f2b0b2988f8cceb0a60b7661d39 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 13:48:46 +0100 Subject: [PATCH 35/44] enable tests on cpu --- .../primitives/tensor_product.py | 62 +++++++++++++++---- .../layers/tp_conv_fully_connected_test.py | 4 +- .../tests/operations/linear_test.py | 18 ++++-- .../tests/operations/rotation_test.py | 12 ++-- .../operations/spherical_harmonics_test.py | 10 +-- .../operations/symmetric_contraction_test.py | 14 +++-- .../tests/operations/tp_channel_wise_test.py | 39 ++++++------ .../operations/tp_fully_connected_test.py | 30 +++++---- .../equivariant_tensor_product_test.py | 42 +++++++------ .../tests/primitives/script_test.py | 56 +++++++++++------ .../symmetric_tensor_product_test.py | 32 ++++++---- .../tests/primitives/tensor_product_test.py | 15 +++-- .../tests/primitives/transpose_test.py | 18 +++--- 13 files changed, 223 insertions(+), 129 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index de5b971..b436161 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -19,6 +19,7 @@ import torch import torch.fx + from cuequivariance import segmented_tensor_product as stp logger = logging.getLogger(__name__) @@ -103,7 +104,7 @@ def __init__( self.f = None self.has_cuda = False self.num_operands = descriptor.num_operands - + if use_fallback is None or use_fallback is False: try: self.f = _tensor_product_cuda(descriptor, device, math_dtype) @@ -278,9 +279,9 @@ def _tensor_product_fx( for operand in descriptor.operands[:num_inputs] ] graphmod = opt_einsum_fx.optimize_einsums_full(graphmod, example_inputs) - else: + elif num_inputs == 0: - class _no_input_or_no_paths(torch.nn.Module): + class _no_input(torch.nn.Module): def __init__(self, descriptor: stp.SegmentedTensorProduct): super().__init__() @@ -292,12 +293,9 @@ def __init__(self, descriptor: stp.SegmentedTensorProduct): ), ) - def forward(self, args: List[torch.Tensor]): - shape = broadcast_shapes([arg.shape[:-1] for arg in args]) + def forward(self): output = torch.zeros( - shape + (descriptor.operands[-1].size,), - device=device, - dtype=math_dtype, + (descriptor.operands[-1].size,), device=device, dtype=math_dtype ) for pid in range(descriptor.num_paths): output += torch.einsum( @@ -308,7 +306,12 @@ def forward(self, args: List[torch.Tensor]): ) return output - graphmod = _no_input_or_no_paths(descriptor) + graphmod = _no_input(descriptor) + + else: + raise NotImplementedError( + "No FX implementation for empty paths and non-empty inputs" + ) return _Wrapper(graphmod, descriptor) @@ -317,33 +320,66 @@ class _Caller(torch.nn.Module): def __init__(self, module: torch.nn.Module): super().__init__() self.module = module - + + class _NoArgCaller(_Caller): def forward(self, args: List[torch.Tensor]): return self.module() + class _OneArgCaller(_Caller): def forward(self, args: List[torch.Tensor]): return self.module(args[0]) + class _TwoArgCaller(_Caller): def forward(self, args: List[torch.Tensor]): return self.module(args[0], args[1]) - + + class _ThreeArgCaller(_Caller): def forward(self, args: List[torch.Tensor]): return self.module(args[0], args[1], args[2]) + class _FourArgCaller(_Caller): def forward(self, args: List[torch.Tensor]): return self.module(args[0], args[1], args[2], args[3]) -CALL_DISPATCHERS = [_NoArgCaller, _OneArgCaller, _TwoArgCaller, _ThreeArgCaller, _FourArgCaller] + +class _FiveArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3], args[4]) + + +class _SixArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3], args[4], args[5]) + + +class _SevenArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module( + args[0], args[1], args[2], args[3], args[4], args[5], args[6] + ) + + +CALL_DISPATCHERS = [ + _NoArgCaller, + _OneArgCaller, + _TwoArgCaller, + _ThreeArgCaller, + _FourArgCaller, + _FiveArgCaller, + _SixArgCaller, + _SevenArgCaller, +] + class _Wrapper(torch.nn.Module): def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProduct): super().__init__() - self.module = CALL_DISPATCHERS[descriptor.num_operands-1](module) + self.module = CALL_DISPATCHERS[descriptor.num_operands - 1](module) self.descriptor = descriptor def forward(self, args: List[torch.Tensor]): diff --git a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py index 198c6d8..96602a3 100644 --- a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py +++ b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py @@ -21,7 +21,7 @@ import cuequivariance_torch as cuet from cuequivariance_torch.layers.tp_conv_fully_connected import scatter_reduce -device = torch.device("cuda:0") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @pytest.mark.parametrize("layout", [cue.mul_ir, cue.ir_mul]) @@ -133,7 +133,6 @@ def D(irreps, axis, angle): @pytest.mark.parametrize("reduce", ["sum", "mean", "prod", "amax", "amin"]) def test_scatter_reduce(reduce: str): - device = torch.device("cuda") src = torch.Tensor([3, 1, 0, 1, 1, 2]) index = torch.Tensor([0, 1, 2, 2, 3, 1]) @@ -153,7 +152,6 @@ def test_scatter_reduce(reduce: str): def test_scatter_reduce_empty(): - device = torch.device("cuda") src, index = torch.empty((0, 41)), torch.empty((0,)) src = src.to(device) index = index.to(device) diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index afd5632..2b78ff1 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -20,6 +20,8 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("SU2", "3x1/2 + 4x1"), cue.Irreps("SU2", "2x1/2 + 5x1 + 2x1/2"), @@ -37,13 +39,16 @@ def test_linear_fwd( layout: cue.IrrepsLayout, shared_weights: bool, ): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + torch.manual_seed(0) linear = cuet.Linear( irreps_in, irreps_out, layout=layout, shared_weights=shared_weights, - device="cuda", + device=device, dtype=torch.float64, use_fallback=False, ) @@ -54,7 +59,7 @@ def test_linear_fwd( irreps_out, layout=layout, shared_weights=shared_weights, - device="cuda", + device=device, dtype=torch.float64, use_fallback=True, ) @@ -83,6 +88,9 @@ def test_linear_bwd_bwd( layout: cue.IrrepsLayout, shared_weights: bool, ): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + outputs = dict() for use_fallback in [True, False]: torch.manual_seed(0) @@ -91,7 +99,7 @@ def test_linear_bwd_bwd( irreps_out, layout=layout, shared_weights=shared_weights, - device="cuda", + device=device, dtype=torch.float64, use_fallback=use_fallback, ) @@ -100,7 +108,7 @@ def test_linear_bwd_bwd( torch.manual_seed(0) x = torch.randn( - 10, irreps_in.dim, requires_grad=True, device="cuda", dtype=torch.float64 + 10, irreps_in.dim, requires_grad=True, device=device, dtype=torch.float64 ) if shared_weights: @@ -158,6 +166,6 @@ def test_linear_copy( irreps_out, layout=layout, shared_weights=shared_weights, - ).cuda() + ).to(device) copy.deepcopy(linear) diff --git a/cuequivariance_torch/tests/operations/rotation_test.py b/cuequivariance_torch/tests/operations/rotation_test.py index a73c68b..86d0230 100644 --- a/cuequivariance_torch/tests/operations/rotation_test.py +++ b/cuequivariance_torch/tests/operations/rotation_test.py @@ -17,16 +17,18 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def test_rotation(): irreps = cue.Irreps("SO3", "3x0 + 1 + 0 + 4x2 + 4") - alpha = torch.tensor(0.3).cuda() - beta = torch.tensor(0.4).cuda() - gamma = torch.tensor(-0.5).cuda() + alpha = torch.tensor(0.3).to(device) + beta = torch.tensor(0.4).to(device) + gamma = torch.tensor(-0.5).to(device) - rot = cuet.Rotation(irreps, layout=cue.ir_mul).cuda() + rot = cuet.Rotation(irreps, layout=cue.ir_mul).to(device) - x = torch.randn(10, irreps.dim).cuda() + x = torch.randn(10, irreps.dim).to(device) rx = rot(gamma, beta, alpha, x) x_ = rot(-alpha, -beta, -gamma, rx) diff --git a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py index 73c3d40..955ee87 100644 --- a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py @@ -19,6 +19,8 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + @pytest.mark.parametrize( "dtype, tol", @@ -26,15 +28,15 @@ ) @pytest.mark.parametrize("ell", [1, 2, 3]) def test_spherical_harmonics(ell: int, dtype, tol): - vec = torch.randn(3, dtype=dtype, device="cuda") + vec = torch.randn(3, dtype=dtype, device=device) axis = np.random.randn(3) angle = np.random.rand() scale = 1.3 yl = cuet.spherical_harmonics([ell], vec, False) - R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).cuda() - Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).cuda() + R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device) + Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device) yl1 = cuet.spherical_harmonics([ell], scale * R @ vec, False) yl2 = scale**ell * Rl @ yl @@ -43,7 +45,7 @@ def test_spherical_harmonics(ell: int, dtype, tol): def test_spherical_harmonics_full(): - vec = torch.randn(3, device="cuda") + vec = torch.randn(3, device=device) ls = [0, 1, 2, 3] yl = cuet.spherical_harmonics(ls, vec, False) diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 46cf3f1..80a4065 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -22,6 +22,8 @@ import cuequivariance_torch as cuet from cuequivariance.experimental.e3nn import O3_e3nn +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + USE_TF32 = False torch.backends.cuda.matmul.allow_tf32 = USE_TF32 torch.backends.cudnn.allow_tf32 = USE_TF32 @@ -45,12 +47,12 @@ def test_symmetric_contraction(dtype, layout, original_mace, batch): layout_out=layout, dtype=dtype, math_dtype=dtype, - device="cuda", + device=device, original_mace=original_mace, ) - x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda() - indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda() + x = torch.randn((batch, irreps_in.dim), dtype=dtype).to(device) + indices = torch.randint(0, 5, (batch,), dtype=torch.int32).to(device) out = m(x, indices) assert out.shape == (batch, irreps_out.dim) @@ -58,7 +60,7 @@ def test_symmetric_contraction(dtype, layout, original_mace, batch): def from64(shape: tuple[int, ...], data: str) -> torch.Tensor: x = np.frombuffer(base64.b64decode(data), dtype=np.float32).reshape(shape) - return torch.from_numpy(x.copy()).cuda() + return torch.from_numpy(x.copy()).to(device) def test_mace_compatibility(): @@ -76,7 +78,7 @@ def test_mace_compatibility(): irreps_in = cue.Irreps(O3_e3nn, "0e + 1o + 2e") irreps_out = cue.Irreps(O3_e3nn, "0e + 1o") - i = (torch.arange(3) % num_elements).cuda() + i = (torch.arange(3) % num_elements).to(device) x = from64( (3, 36), "mHgaP1zHTz5kdhs/3ygQvwzZf77dhoU8+iP+PzzRRD8L9CY+qi9Fv5aiBz/sGJG/xwaev+5w4b2Mbg8+1jDOP4/cwj9rt/u/FedUP7H6qD4y9LM+i7yvPhcifz8coHE/Vkk1PwK0hb/BNig+GF4gP1FNaD94Uj++d+1qPtkrYD8m8o6/9zK9PihGBz9M6Ne9XgCXP/r6bzxTXJO/glIsQPQlDL/fN5w7VeeKP4iYlD/9Msa/GF/cvg+2gz/oRJ6/0Te4P7g+oz8YQ6g+k0q0vN8WEr41/u0/sa55PmAhvD9FZZw/ICJtvyxFkz+zOAq/8JtNPztZX74E9hK/xCdqv4+0Rz9Ah/g+5vmDv6mLL7+M5DI/xgP3PhWEnj5ZmZ0+DBkXwPa12D1mVPo9rDdWP4DkRD+L85Y9EJ01P+8Hiz6gxSM7/eoPwOQOtr8gjge+NBEYPrmg5L2XpO8/F2tCvjEyWL8gjLw+UOIuP5bhPr9qRvM+ADa5v3rqLLwSr/8+PbZhP4tn675SWVm/SMC1P5h/0r0D8v2/CNS7Pza7SL8PqJG+DsKCOpTKoT+xnLg/", @@ -95,7 +97,7 @@ def test_mace_compatibility(): layout_in=cue.ir_mul, layout_out=cue.mul_ir, original_mace=True, - device="cuda", + device=device, dtype=torch.float32, math_dtype=torch.float64, ) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index d3628ab..d3e7cdd 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -19,6 +19,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), @@ -40,54 +42,49 @@ def test_channel_wise( use_fallback: bool, batch: int, ): - m = cuet.ChannelWiseTensorProduct( - irreps1, - irreps2, - irreps3, - shared_weights=True, - internal_weights=True, - layout=layout, - device="cuda", - dtype=torch.float64, - ) - m_fx = cuet.ChannelWiseTensorProduct( + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m1 = cuet.ChannelWiseTensorProduct( irreps1, irreps2, irreps3, shared_weights=True, internal_weights=True, layout=layout, - device="cuda", + device=device, dtype=torch.float64, - use_fallback=True, + use_fallback=use_fallback, ) + x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).to(device) + x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).to(device) - x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() - - out1 = m(x1, x2) + out1 = m1(x1, x2) d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: d = d.add_or_transpose_modes("u,ui,j,uk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).cuda() - out2 = mfx([m.weight, x1, x2]) + m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) + out2 = m2([m1.weight, x1, x2]) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) def test_channel_wise_bwd_bwd(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + irreps1 = cue.Irreps("SO3", "2x0 + 3x1") irreps2 = cue.Irreps("SO3", "0 + 1") irreps3 = cue.Irreps("SO3", "0 + 1") x1 = torch.randn( - 32, irreps1.dim, device="cuda", requires_grad=True, dtype=torch.float64 + 32, irreps1.dim, device=device, requires_grad=True, dtype=torch.float64 ) x2 = torch.randn( - 32, irreps2.dim, device="cuda", requires_grad=True, dtype=torch.float64 + 32, irreps2.dim, device=device, requires_grad=True, dtype=torch.float64 ) outputs = {} diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index d9b19b4..832904b 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -19,6 +19,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("O3", "4x0e + 4x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), @@ -38,43 +40,49 @@ def test_fully_connected( layout: cue.IrrepsLayout, use_fallback: bool, ): - m = cuet.FullyConnectedTensorProduct( + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m1 = cuet.FullyConnectedTensorProduct( irreps1, irreps2, irreps3, shared_weights=True, internal_weights=True, layout=layout, - device="cuda", + device=device, dtype=torch.float64, use_fallback=use_fallback, ) - x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() + x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).to(device) + x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).to(device) - out1 = m(x1, x2) + out1 = m1(x1, x2) d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).cuda() - out2 = mfx( - [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], + m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) + out2 = m2( + [m1.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], ).to(out1.dtype) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) -def test_compile(): - device = "cuda" +@pytest.mark.parametrize("use_fallback", [False, True]) +def test_compile(use_fallback: bool): + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + m = cuet.FullyConnectedTensorProduct( irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"), irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"), irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), layout=cue.mul_ir, device=device, - use_fallback=False, + use_fallback=use_fallback, ) m_compile = torch.compile(m, fullgraph=True) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 01ca6d2..68531a7 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -17,17 +17,18 @@ import pytest import torch import torch._dynamo +from utils import ( + module_with_mode, +) import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors +torch._dynamo.config.cache_size_limit = 100 -from utils import ( - module_with_mode, -) +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -torch._dynamo.config.cache_size_limit = 100 def make_descriptors(): # This ETP will trigger the fusedTP kernel @@ -61,7 +62,7 @@ def make_descriptors(): (torch.float32, torch.float32), (torch.float64, torch.float64), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings1 += [ (torch.float16, torch.float32), (torch.bfloat16, torch.float32), @@ -75,7 +76,8 @@ def test_performance_cuda_vs_fx( dtype: torch.dtype, math_dtype: torch.dtype, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( e, @@ -125,7 +127,7 @@ def f1(): (torch.float64, torch.float32, 1e-5, 1e-6), (torch.float64, torch.float64, 1e-12, 0), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings2 += [ (torch.float16, torch.float32, 1, 0.2), (torch.bfloat16, torch.float32, 1, 0.2), @@ -141,7 +143,8 @@ def test_precision_cuda_vs_fx( atol: float, rtol: float, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -175,10 +178,11 @@ def test_compile( atol: float, rtol: float, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda", math_dtype=math_dtype + e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -199,10 +203,11 @@ def test_script( atol: float, rtol: float, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, use_fallback=False, device="cuda", math_dtype=math_dtype + e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) @@ -213,13 +218,14 @@ def test_script( res_script = m_script(inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + # export_modes = ["onnx", "onnx_dynamo", "trt", "torch_trt", "jit"] -export_modes = ["trt","onnx"] +export_modes = ["trt", "onnx"] + @pytest.mark.parametrize("e", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) @pytest.mark.parametrize("mode", export_modes) - def test_export( e: cue.EquivariantTensorProduct, dtype: torch.dtype, @@ -227,13 +233,13 @@ def test_export( atol: float, rtol: float, mode: str, - tmp_path + tmp_path, ): - - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( - e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device="cuda" + e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device=device ) inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/script_test.py index 37b2a0c..4706bff 100644 --- a/cuequivariance_torch/tests/primitives/script_test.py +++ b/cuequivariance_torch/tests/primitives/script_test.py @@ -1,3 +1,4 @@ +import pytest import torch import cuequivariance as cue @@ -11,26 +12,32 @@ TensorProductUniform4x1d, ) +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def test_script_symmetric_contraction(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + ds = cue.descriptors.symmetric_contraction( 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] ).ds batch = 12 - x0 = torch.randn(3, ds[0].operands[0].size, device="cuda:0", dtype=torch.float32) - i0 = torch.zeros(batch, device="cuda:0", dtype=torch.int32) - x1 = torch.randn( - batch, ds[0].operands[1].size, device="cuda:0", dtype=torch.float32 - ) + x0 = torch.randn(3, ds[0].operands[0].size, device=device, dtype=torch.float32) + i0 = torch.zeros(batch, device=device, dtype=torch.int32) + x1 = torch.randn(batch, ds[0].operands[1].size, device=device, dtype=torch.float32) - module = SymmetricTensorProduct(ds, torch.device("cuda:0"), torch.float32) + module = SymmetricTensorProduct(ds, device, torch.float32) module = torch.jit.script(module) assert module(x0, i0, x1).shape == (batch, ds[0].operands[-1].size) def test_script_fused_tp_3(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + d = ( cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") @@ -40,16 +47,19 @@ def test_script_fused_tp_3(): ) batch = 12 - x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) - x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) - module = FusedTensorProductOp3(d, (0, 1), torch.device("cuda:0"), torch.float32) + module = FusedTensorProductOp3(d, (0, 1), device, torch.float32) module = torch.jit.script(module) assert module([x0, x1]).shape == (batch, d.operands[2].size) def test_script_fused_tp_4(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + d = ( cue.descriptors.fully_connected_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") @@ -60,17 +70,20 @@ def test_script_fused_tp_4(): ) batch = 12 - x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) - x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) - x2 = torch.randn(batch, d.operands[2].size, device="cuda:0", dtype=torch.float32) + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) - module = FusedTensorProductOp4(d, (0, 1, 2), torch.device("cuda:0"), torch.float32) + module = FusedTensorProductOp4(d, (0, 1, 2), device, torch.float32) module = torch.jit.script(module) assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) def test_script_uniform_tp_3(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + d = ( cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") @@ -80,16 +93,19 @@ def test_script_uniform_tp_3(): ) batch = 12 - x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) - x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) - module = TensorProductUniform3x1d(d, torch.device("cuda:0"), torch.float32) + module = TensorProductUniform3x1d(d, device, torch.float32) module = torch.jit.script(module) assert module([x0, x1]).shape == (batch, d.operands[2].size) def test_script_uniform_tp_4(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + d = ( cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") @@ -99,11 +115,11 @@ def test_script_uniform_tp_4(): ) batch = 12 - x0 = torch.randn(batch, d.operands[0].size, device="cuda:0", dtype=torch.float32) - x1 = torch.randn(batch, d.operands[1].size, device="cuda:0", dtype=torch.float32) - x2 = torch.randn(batch, d.operands[2].size, device="cuda:0", dtype=torch.float32) + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) - module = TensorProductUniform4x1d(d, torch.device("cuda:0"), torch.float32) + module = TensorProductUniform4x1d(d, device, torch.float32) module = torch.jit.script(module) assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index f5ab6aa..7858576 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -20,6 +20,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def make_descriptors(): yield descriptors.symmetric_contraction( @@ -47,7 +49,7 @@ def make_descriptors(): (torch.float32, torch.float64, 1e-5), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings1 += [ (torch.float16, torch.float32, 1.0), (torch.float16, torch.float64, 0.1), @@ -58,15 +60,22 @@ def make_descriptors(): @pytest.mark.parametrize("ds", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings1) +@pytest.mark.parametrize("use_fallback", [False, True]) def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( - ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float + ds: list[stp.SegmentedTensorProduct], + dtype, + math_dtype, + tol: float, + use_fallback: bool, ): - device = torch.device("cuda:0") + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=math_dtype, device=device, use_fallback=False + ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) - m = torch.jit.script(m) + if use_fallback is False: + m = torch.jit.script(m) x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True) i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device) @@ -109,7 +118,7 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( (torch.float32, torch.float64), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings2 += [ (torch.float16, torch.float32), (torch.bfloat16, torch.float32), @@ -117,17 +126,16 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( @pytest.mark.parametrize("dtype, math_dtype", settings2) -def test_math_dtype( - dtype: torch.dtype, - math_dtype: torch.dtype, -): - device = torch.device("cuda:0") +@pytest.mark.parametrize("use_fallback", [False, True]) +def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: bool): + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") ds = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=math_dtype, device=device, use_fallback=False + ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device) i0 = torch.randint(0, m.x0_size, (1000,), dtype=torch.int32, device=device) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index e4ba7cc..d8c26ef 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -22,6 +22,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def make_descriptors(): yield descriptors.fully_connected_tensor_product( @@ -80,7 +82,7 @@ def make_descriptors(): (torch.float64, torch.float64, 1e-12), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings += [ (torch.float16, torch.float32, 1.0), (torch.bfloat16, torch.float32, 1.0), @@ -89,13 +91,16 @@ def make_descriptors(): @pytest.mark.parametrize("d", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings) +@pytest.mark.parametrize("use_fallback", [False, True]) def test_primitive_tensor_product_cuda_vs_fx( d: stp.SegmentedTensorProduct, dtype: torch.dtype, math_dtype: torch.dtype, tol: float, + use_fallback: bool, ): - device = torch.device("cuda:0") + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") for batches in itertools.product([(16,), (), (4, 1)], repeat=d.num_operands - 1): inputs = [ @@ -109,9 +114,11 @@ def test_primitive_tensor_product_cuda_vs_fx( ] m = cuet.TensorProduct( - d, device=device, math_dtype=math_dtype, optimize_fallback=False + d, device=device, math_dtype=math_dtype, use_fallback=use_fallback ) - m = torch.jit.script(m) + if not use_fallback: + m = torch.jit.script(m) + out1 = m(inputs) m = cuet.TensorProduct( diff --git a/cuequivariance_torch/tests/primitives/transpose_test.py b/cuequivariance_torch/tests/primitives/transpose_test.py index 67ad700..31eb271 100644 --- a/cuequivariance_torch/tests/primitives/transpose_test.py +++ b/cuequivariance_torch/tests/primitives/transpose_test.py @@ -12,13 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch + import cuequivariance_torch as cuet -import pytest +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") dtypes = [torch.float32, torch.float64] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: dtypes += [torch.float16, torch.bfloat16] @@ -33,14 +35,16 @@ def test_transpose(use_fallback: bool, dtype: torch.dtype): 10 11 10 12 12 13 11 13 """ + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") x = torch.tensor( - [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10, 11, 12, 13]], dtype=dtype - ).cuda() + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10, 11, 12, 13]], dtype=dtype, device=device + ) segments = [(2, 3), (2, 2)] xt = torch.tensor( - [[1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 10, 12, 11, 13]], dtype=dtype - ).cuda() + [[1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 10, 12, 11, 13]], dtype=dtype, device=device + ) - m = cuet.TransposeSegments(segments, use_fallback=use_fallback).cuda() + m = cuet.TransposeSegments(segments, device, use_fallback=use_fallback) torch.testing.assert_close(m(x), xt) From 8afa05674733a3a56e1e5b6a6f3c7c102077392e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 13:51:01 +0100 Subject: [PATCH 36/44] fix tests --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index dadac6c..828840a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,8 +56,8 @@ jobs: pytest --doctest-modules cuequivariance_jax cuequivariance-torch: - - runs-on: self-hosted + # runs-on: self-hosted (temporary unavailable) + runs-on: ubuntu-latest strategy: fail-fast: false matrix: From 09bbc8d95d08c88723500221ab29c8d2c1e30055 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 13:53:22 +0100 Subject: [PATCH 37/44] fix ruff --- .../cuequivariance_torch/primitives/symmetric_tensor_product.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 69f4a7c..dc61230 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math from typing import Optional import torch From 9c38168cbc5f854ecaa579319b8825e395fdb58a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 13:55:24 +0100 Subject: [PATCH 38/44] fix --- .../cuequivariance_torch/layers/tp_conv_fully_connected.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index dd1a255..b6a6fb0 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -70,7 +70,7 @@ class FullyConnectedTensorProductConv(nn.Module): having 16 channels. edge_emb.size(1) must match the size of the input layer: 6 >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, - ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul).cuda() + ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul) >>> conv1 FullyConnectedTensorProductConv(...) >>> # out = conv1(src_features, edge_sh, edge_emb, graph) @@ -92,7 +92,7 @@ class FullyConnectedTensorProductConv(nn.Module): **Case 3**: No MLP, edge_emb will be directly used as the tensor product weights: >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, - ... mlp_channels=None, layout=cue.ir_mul).cuda() + ... mlp_channels=None, layout=cue.ir_mul) >>> # out = conv3(src_features, edge_sh, edge_emb, graph) """ From de9af8f48c42ed7116b73612fbb507db17d89730 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 13:57:13 +0100 Subject: [PATCH 39/44] fix docstring tests --- .../operations/symmetric_contraction.py | 7 ++++--- .../primitives/equivariant_tensor_product.py | 11 ++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index e1b8d66..a91a72a 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -52,6 +52,7 @@ class SymmetricContraction(torch.nn.Module): The argument `original_mace` can be set to `True` to emulate the original MACE implementation. + >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e") >>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o") >>> # OLD FUNCTION DEFINITION: @@ -71,15 +72,15 @@ class SymmetricContraction(torch.nn.Module): ... layout_out=cue.mul_ir, ... original_mace=True, ... dtype=torch.float64, - ... device=torch.device("cuda"), + ... device=device, ... ) Then the execution is as follows: - >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64).cuda() + >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device) >>> # with node_attrs_index being the index version of node_attrs, sth like: >>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int() - >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32).cuda() + >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device) >>> # OLD CALL: >>> # symmetric_contractions_old(node_feats, node_attrs) >>> # NEW CALL: diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 3bd3b63..16701ff 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -124,19 +124,20 @@ class EquivariantTensorProduct(torch.nn.Module): RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. Examples: + >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... ) - >>> w = torch.ones(e.inputs[0].irreps.dim).cuda() - >>> x1 = torch.ones(17, e.inputs[1].irreps.dim).cuda() - >>> x2 = torch.ones(17, e.inputs[2].irreps.dim).cuda() - >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=torch.device("cuda")) + >>> w = torch.ones(e.inputs[0].irreps.dim, device=device) + >>> x1 = torch.ones(17, e.inputs[1].irreps.dim, device=device) + >>> x2 = torch.ones(17, e.inputs[2].irreps.dim, device=device) + >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) >>> tp([w, x1, x2]) tensor([[0., 0., 0., 0., 0., 0.],...) You can optionally index the first input tensor: - >>> w = torch.ones(3, e.inputs[0].irreps.dim).cuda() + >>> w = torch.ones(3, e.inputs[0].irreps.dim, device=device) >>> indices = torch.randint(3, (17,)) >>> tp([w, x1, x2], indices=indices) tensor([[0., 0., 0., 0., 0., 0.],...) From 999a31defde6112659f87fe24484328cbb26eb26 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 14:00:11 +0100 Subject: [PATCH 40/44] add -x to tests --- .github/workflows/tests.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 828840a..def808e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,7 @@ jobs: python -m pip install ./cuequivariance - name: Test with pytest run: | - pytest --doctest-modules cuequivariance + pytest --doctest-modules -x cuequivariance cuequivariance-jax: @@ -53,7 +53,7 @@ jobs: python -m pip install ./cuequivariance_jax - name: Test with pytest run: | - pytest --doctest-modules cuequivariance_jax + pytest --doctest-modules -x cuequivariance_jax cuequivariance-torch: # runs-on: self-hosted (temporary unavailable) @@ -79,4 +79,4 @@ jobs: python -m pip install e3nn - name: Test with pytest run: | - pytest --doctest-modules cuequivariance_torch + pytest --doctest-modules -x cuequivariance_torch From 905e716102d2b8737914c33012f178db077dbb73 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 14:02:15 +0100 Subject: [PATCH 41/44] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 867796d..3ab694f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Latest Changes +### Added + +- Partial support of `torch.jit.script` and `torch.compile` + ### Changed - `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. From 975e9c87d3107d64c23ab43d63682fd04937053e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 15:36:25 +0100 Subject: [PATCH 42/44] test --- .github/workflows/tests.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index def808e..19ae3ab 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,12 +56,13 @@ jobs: pytest --doctest-modules -x cuequivariance_jax cuequivariance-torch: - # runs-on: self-hosted (temporary unavailable) - runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.12"] + runner: ["ubuntu-latest", "self-hosted"] + + runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v4 @@ -72,11 +73,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - python -m pip install torch + python -m pip install pytest torch e3nn python -m pip install ./cuequivariance python -m pip install ./cuequivariance_torch - python -m pip install e3nn - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch From 093e8e40101bd23e61e61d1c8838fedac365c561 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 06:49:12 -0800 Subject: [PATCH 43/44] move utils into test file --- .../equivariant_tensor_product_test.py | 193 ++++++++++++- .../tests/primitives/utils.py | 267 ------------------ 2 files changed, 190 insertions(+), 270 deletions(-) delete mode 100644 cuequivariance_torch/tests/primitives/utils.py diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 68531a7..09759c1 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import timeit +from typing import Sequence import pytest import torch import torch._dynamo -from utils import ( - module_with_mode, -) import cuequivariance as cue import cuequivariance_torch as cuet @@ -27,6 +26,194 @@ torch._dynamo.config.cache_size_limit = 100 + +try: + import cuequivariance_ops_torch.onnx # noqa: F401 + import onnx # noqa: F401 + import onnxruntime # noqa: F401 + import onnxscript # noqa: F401 + from cuequivariance_ops_torch.tensorrt import register_plugins + + ONNX_AVAILABLE = True +except Exception: + ONNX_AVAILABLE = False + + +try: + import torch_tensorrt + + TORCH_TRT_AVAILABLE = True +except Exception: + TORCH_TRT_AVAILABLE = False + + +def verify_onnx(module, onnx_module, inputs, dtype): + if dtype != torch.float32: + pytest.skip("onnxrt only checked for float32") + from onnxruntime import SessionOptions + from onnxruntime_extensions import get_library_path + from torch.onnx.verification import ( + VerificationOptions, + _compare_onnx_pytorch_model, + ) + + original_init = SessionOptions.__init__ + + def new_init(self): + original_init(self) + try: + self.register_custom_ops_library(get_library_path()) + except Exception: + pass + + SessionOptions.__init__ = new_init + _compare_onnx_pytorch_model( + module, onnx_module, tuple(inputs), None, None, VerificationOptions() + ) + SessionOptions.__init__ = original_init + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def verify_trt(module, onnx_module, inputs, dtype): + import tensorrt + from pkg_resources import parse_version + + if parse_version(tensorrt.__version__) < parse_version("10.3.0"): + pytest.skip("TRT < 10.3.0 is not supported!") + if dtype == torch.float64: + pytest.skip("TRT does not support float64") + + from onnxruntime import InferenceSession, SessionOptions + from onnxruntime_extensions import get_library_path + from polygraphy.backend.onnxrt import OnnxrtRunner + from polygraphy.backend.trt import ( + CreateConfig, + TrtRunner, + engine_from_network, + network_from_onnx_path, + ) + from polygraphy.comparator import Comparator, DataLoader + + register_plugins() + + network = network_from_onnx_path(onnx_module) + trt_engine = engine_from_network(network, config=CreateConfig()) + + if dtype != torch.float32: + pytest.skip("Comparator only supports float32") + + # Create runners for ONNX and TRT models + trt_runner = TrtRunner(trt_engine) + + options = SessionOptions() + options.register_custom_ops_library(get_library_path()) + onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) + + results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) + Comparator.compare_accuracy(results) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def module_with_mode( + mode, + module, + inputs, + math_dtype, + tmp_path, + grad_modes=["eager", "compile", "jit", "export"], +): + if isinstance(inputs[0], list): + dtype = inputs[0][0].dtype + else: + dtype = inputs[0].dtype + if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo", "export"]: + if not ONNX_AVAILABLE: + pytest.skip("ONNX not available!") + if dtype == torch.float64 or math_dtype == torch.float64: + pytest.skip("TRT/ORT do not support float64") + + with torch.set_grad_enabled(mode in grad_modes): + if mode == "compile": + import sys + + if sys.version_info.major == 3 and sys.version_info.minor >= 12: + pytest.skip("torch dynamo needs cpy <= 3.11") + module = torch.compile(module) + elif mode == "fx": + module = torch.fx.symbolic_trace(module) + elif mode == "jit": + module = torch.jit.trace(module, inputs) + fname = os.path.join(tmp_path, "test.ts") + torch.jit.save(module, fname) + module = torch.jit.load(fname) + elif mode == "export": + exp_program = torch.export.export(module, tuple(inputs)) + fname = os.path.join(tmp_path, "test.pt2") + torch.export.save(exp_program, fname) + del exp_program + module = torch.export.load(fname).module() + elif mode == "torch_trt": + if not TORCH_TRT_AVAILABLE: + pytest.skip("torch_tensorrt is not installed!") + register_plugins() + exp_program = torch_tensorrt.dynamo.trace(module, inputs) + module = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + require_full_compilation=True, + min_block_size=1, + enabled_precisions={torch.float32, dtype}, + # dryrun=True + ) + elif mode == "onnx" or mode == "trt": + try: + onnx_path = os.path.join(tmp_path, "test.onnx") + torch.onnx.export( + module, tuple(inputs), onnx_path, opset_version=17, verbose=False + ) + if mode == "trt": + verify_trt(module, onnx_path, inputs, dtype) + else: + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX/TRT is not available") + + elif mode == "onnx_dynamo": + try: + from cuequivariance_ops_torch.onnx import ( + cuequivariance_ops_torch_onnx_registry, + ) + + export_options = torch.onnx.ExportOptions( + onnx_registry=cuequivariance_ops_torch_onnx_registry + ) + onnx_program = torch.onnx.dynamo_export( + module, *inputs, export_options=export_options + ) + onnx_path = os.path.join(tmp_path, "test.onnx") + onnx_program.save(onnx_path) + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX is not available") + elif mode == "eager": + pass + else: + raise ValueError(f"No such mode: {mode}") + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + return module + + +def maybe_detach_and_to(tensor, *args, **kwargs): + if tensor is not None: + return tensor.clone().detach().to(*args, **kwargs) + return None + + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") diff --git a/cuequivariance_torch/tests/primitives/utils.py b/cuequivariance_torch/tests/primitives/utils.py deleted file mode 100644 index ce646fd..0000000 --- a/cuequivariance_torch/tests/primitives/utils.py +++ /dev/null @@ -1,267 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. - -import os -import pytest -import torch -from typing import Sequence - -try: - import onnx # noqa: F401 - import onnxscript # noqa: F401 - import onnxruntime # noqa: F401 - import cuequivariance_ops_torch.onnx # noqa: F401 - from cuequivariance_ops_torch.tensorrt import register_plugins - - ONNX_AVAILABLE = True -except Exception: - ONNX_AVAILABLE = False - - -try: - import torch_tensorrt - - TORCH_TRT_AVAILABLE = True -except Exception: - TORCH_TRT_AVAILABLE = False - - -def verify_onnx(module, onnx_module, inputs, dtype): - if dtype != torch.float32: - pytest.skip("onnxrt only checked for float32") - from onnxruntime import SessionOptions - from onnxruntime_extensions import get_library_path - from torch.onnx.verification import ( - _compare_onnx_pytorch_model, - VerificationOptions, - ) - - original_init = SessionOptions.__init__ - - def new_init(self): - original_init(self) - try: - self.register_custom_ops_library(get_library_path()) - except Exception: - pass - - SessionOptions.__init__ = new_init - _compare_onnx_pytorch_model( - module, onnx_module, tuple(inputs), None, None, VerificationOptions() - ) - SessionOptions.__init__ = original_init - torch.cuda.synchronize() - torch.cuda.empty_cache() - - -def verify_trt(module, onnx_module, inputs, dtype): - import tensorrt - from pkg_resources import parse_version - - if parse_version(tensorrt.__version__) < parse_version("10.3.0"): - pytest.skip("TRT < 10.3.0 is not supported!") - if dtype == torch.float64: - pytest.skip("TRT does not support float64") - - from polygraphy.backend.trt import ( - engine_from_network, - network_from_onnx_path, - TrtRunner, - CreateConfig, - ) - from polygraphy.backend.onnxrt import OnnxrtRunner - from polygraphy.comparator import Comparator, DataLoader - from onnxruntime import InferenceSession, SessionOptions - from onnxruntime_extensions import get_library_path - - register_plugins() - - network = network_from_onnx_path(onnx_module) - trt_engine = engine_from_network(network, config=CreateConfig()) - - if dtype != torch.float32: - pytest.skip("Comparator only supports float32") - - # Create runners for ONNX and TRT models - trt_runner = TrtRunner(trt_engine) - - options = SessionOptions() - options.register_custom_ops_library(get_library_path()) - onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) - - results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) - Comparator.compare_accuracy(results) - torch.cuda.synchronize() - torch.cuda.empty_cache() - - -def module_with_mode( - mode, - module, - inputs, - math_dtype, - tmp_path, - grad_modes=["eager", "compile", "jit", "export"], -): - if isinstance(inputs[0], list): - dtype = inputs[0][0].dtype - else: - dtype = inputs[0].dtype - if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo", "export"]: - if not ONNX_AVAILABLE: - pytest.skip("ONNX not available!") - if dtype == torch.float64 or math_dtype == torch.float64: - pytest.skip("TRT/ORT do not support float64") - - with torch.set_grad_enabled(mode in grad_modes): - if mode == "compile": - import sys - - if sys.version_info.major == 3 and sys.version_info.minor >= 12: - pytest.skip("torch dynamo needs cpy <= 3.11") - module = torch.compile(module) - elif mode == "fx": - module = torch.fx.symbolic_trace(module) - elif mode == "jit": - module = torch.jit.trace(module, inputs) - fname = os.path.join(tmp_path, "test.ts") - torch.jit.save(module, fname) - module = torch.jit.load(fname) - elif mode == "export": - exp_program = torch.export.export(module, tuple(inputs)) - fname = os.path.join(tmp_path, "test.pt2") - torch.export.save(exp_program, fname) - del exp_program - module = torch.export.load(fname).module() - elif mode == "torch_trt": - if not TORCH_TRT_AVAILABLE: - pytest.skip("torch_tensorrt is not installed!") - register_plugins() - exp_program = torch_tensorrt.dynamo.trace(module, inputs) - module = torch_tensorrt.dynamo.compile( - exp_program, - inputs=inputs, - require_full_compilation=True, - min_block_size=1, - enabled_precisions={torch.float32, dtype}, - # dryrun=True - ) - elif mode == "onnx" or mode == "trt": - try: - onnx_path = os.path.join(tmp_path, "test.onnx") - torch.onnx.export( - module, tuple(inputs), onnx_path, opset_version=17, verbose=False - ) - if mode == "trt": - verify_trt(module, onnx_path, inputs, dtype) - else: - verify_onnx(module, onnx_path, inputs, dtype) - except ImportError: - pytest.skip("ONNX/TRT is not available") - - elif mode == "onnx_dynamo": - try: - from cuequivariance_ops_torch.onnx import ( - cuequivariance_ops_torch_onnx_registry, - ) - - export_options = torch.onnx.ExportOptions( - onnx_registry=cuequivariance_ops_torch_onnx_registry - ) - onnx_program = torch.onnx.dynamo_export( - module, *inputs, export_options=export_options - ) - onnx_path = os.path.join(tmp_path, "test.onnx") - onnx_program.save(onnx_path) - verify_onnx(module, onnx_path, inputs, dtype) - except ImportError: - pytest.skip("ONNX is not available") - elif mode == "eager": - pass - else: - raise ValueError(f"No such mode: {mode}") - - torch.cuda.synchronize() - torch.cuda.empty_cache() - - return module - - -def create_random_tensor_2d(batch_size, stride, requires_grad, dtype, is_shared): - data = torch.randn( - (stride,) if is_shared else (batch_size, stride), - dtype=dtype, - device="cuda", - ).requires_grad_(requires_grad) - - return data - - -def maybe_detach_and_to(tensor, *args, **kwargs): - if tensor is not None: - return tensor.clone().detach().to(*args, **kwargs) - return None - - -def run_fwd_test(module, x: Sequence): - with torch.no_grad(): - out = module(*x) - test_output = [maybe_detach_and_to(out, dtype=torch.float32)] - return test_output - - -def run_fwd_bwd_test(module, x: Sequence): - out = module(*x) - - loss = out.sum() - loss.backward() - - test_output = [maybe_detach_and_to(out, dtype=torch.float32)] - test_output.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) - - return test_output - - -def run_bwd_bwd_test(module, x: Sequence): - test_outputs = [] - out = module(*x) - grads = torch.autograd.grad(out.pow(2).sum(), x, create_graph=True) - test_outputs.extend([maybe_detach_and_to(g, dtype=torch.float32) for g in grads]) - loss = sum([g.sum() for g in grads]) - loss.backward() - test_outputs.extend([maybe_detach_and_to(t.grad, dtype=torch.float32) for t in x]) - return test_outputs - - -def assert_close_modules(m_test, m_ref, inputs_test, procedure, tol_dict): - outs_test = procedure(m_test, inputs_test) - - inputs_ref = [ - x.clone() - .detach() - .to(device="cuda", dtype=torch.float32) - .requires_grad_(x.requires_grad) - for x in inputs_test - ] - outs_ref = procedure(m_ref, inputs_ref) - for out_test, out_ref in zip(outs_test, outs_ref): - torch.testing.assert_close(out_test, out_ref, **tol_dict) - - -tol_dict = { - # we compare against double for precision reasons - # hence FP64 and FP32 threshold are the same - (torch.float64, torch.float64): {"atol": 1e-9, "rtol": 1e-5}, - (torch.float32, torch.float64): {"atol": 1e-4, "rtol": 1e-5}, - (torch.float64, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, - (torch.float32, torch.float32): {"atol": 1e-4, "rtol": 1e-5}, - (torch.bfloat16, torch.float32): {"atol": 4.0, "rtol": 1e-2}, - (torch.float16, torch.float32): {"atol": 0.25, "rtol": 1e-2}, -} From 2712f541f826975e16e4e5197a5bca0f1015c9ad Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 9 Dec 2024 06:51:52 -0800 Subject: [PATCH 44/44] fix --- .../tests/primitives/equivariant_tensor_product_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 09759c1..043e387 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -14,7 +14,6 @@ # limitations under the License. import os import timeit -from typing import Sequence import pytest import torch