From 5b6da8589da50556ccb839819ad8ec3321237d0e Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 13 Jan 2025 02:31:25 -0800 Subject: [PATCH] fix tests --- .../equivariant_tensor_product_test.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 603d258..2575e67 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -17,20 +17,17 @@ import pytest import torch import torch._dynamo -from tests.utils import ( - module_with_mode, -) +from tests.utils import module_with_mode import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance import descriptors device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") def make_descriptors(): # This ETP will trigger the fusedTP kernel - yield descriptors.fully_connected_tensor_product( + yield cue.descriptors.fully_connected_tensor_product( cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "0e + 1o + 2e"), cue.Irreps("O3", "32x0e + 32x1o"), @@ -38,7 +35,7 @@ def make_descriptors(): # This ETP will trigger the uniform1dx4 kernel yield ( - descriptors.channelwise_tensor_product( + cue.descriptors.channelwise_tensor_product( cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "0e + 1o + 2e"), cue.Irreps("O3", "0e + 1o"), @@ -48,8 +45,8 @@ def make_descriptors(): ) # These ETPs will trigger the symmetricContraction kernel - yield descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5]) - yield descriptors.symmetric_contraction( + yield cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) + yield cue.descriptors.symmetric_contraction( cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [0, 1, 2, 3], @@ -184,3 +181,25 @@ def test_export( m = module_with_mode(mode, m, exp_inputs, math_dtype, tmp_path) res_script = m(*inputs) torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("batch_size", [0, 5]) +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_high_degrees(use_fallback: bool, batch_size: int): + if not use_fallback and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5]) + m = cuet.EquivariantTensorProduct( + e, + layout=cue.mul_ir, + device=device, + math_dtype=torch.float32, + use_fallback=use_fallback, + ) + inputs = [ + torch.randn((batch_size, rep.dim), device=device, dtype=torch.float32) + for rep in e.inputs + ] + output = m(*inputs) + assert output.shape == (batch_size, e.output.dim)