Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix tests #62

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,25 @@
import pytest
import torch
import torch._dynamo
from tests.utils import (
module_with_mode,
)
from tests.utils import module_with_mode

import cuequivariance as cue
import cuequivariance_torch as cuet
from cuequivariance import descriptors

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def make_descriptors():
# This ETP will trigger the fusedTP kernel
yield descriptors.fully_connected_tensor_product(
yield cue.descriptors.fully_connected_tensor_product(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "0e + 1o + 2e"),
cue.Irreps("O3", "32x0e + 32x1o"),
).flatten_coefficient_modes()

# This ETP will trigger the uniform1dx4 kernel
yield (
descriptors.channelwise_tensor_product(
cue.descriptors.channelwise_tensor_product(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "0e + 1o + 2e"),
cue.Irreps("O3", "0e + 1o"),
Expand All @@ -48,8 +45,8 @@ def make_descriptors():
)

# These ETPs will trigger the symmetricContraction kernel
yield descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I put back a lower degree polynomial.
The degree 5 was giving a bit too high errors compared to all the other ETP tested.

yield descriptors.symmetric_contraction(
yield cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])
yield cue.descriptors.symmetric_contraction(
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "32x0e + 32x1o"),
[0, 1, 2, 3],
Expand Down Expand Up @@ -184,3 +181,25 @@ def test_export(
m = module_with_mode(mode, m, exp_inputs, math_dtype, tmp_path)
res_script = m(*inputs)
torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol)


@pytest.mark.parametrize("batch_size", [0, 5])
@pytest.mark.parametrize("use_fallback", [True, False])
def test_high_degrees(use_fallback: bool, batch_size: int):
if not use_fallback and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5])
m = cuet.EquivariantTensorProduct(
e,
layout=cue.mul_ir,
device=device,
math_dtype=torch.float32,
use_fallback=use_fallback,
)
inputs = [
torch.randn((batch_size, rep.dim), device=device, dtype=torch.float32)
for rep in e.inputs
]
output = m(*inputs)
assert output.shape == (batch_size, e.output.dim)
Comment on lines +186 to +205
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I separately test if a high degree polynomial goes through.

Loading