-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Fix tests #62
Merged
Fix tests #62
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,28 +17,25 @@ | |
import pytest | ||
import torch | ||
import torch._dynamo | ||
from tests.utils import ( | ||
module_with_mode, | ||
) | ||
from tests.utils import module_with_mode | ||
|
||
import cuequivariance as cue | ||
import cuequivariance_torch as cuet | ||
from cuequivariance import descriptors | ||
|
||
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") | ||
|
||
|
||
def make_descriptors(): | ||
# This ETP will trigger the fusedTP kernel | ||
yield descriptors.fully_connected_tensor_product( | ||
yield cue.descriptors.fully_connected_tensor_product( | ||
cue.Irreps("O3", "32x0e + 32x1o"), | ||
cue.Irreps("O3", "0e + 1o + 2e"), | ||
cue.Irreps("O3", "32x0e + 32x1o"), | ||
).flatten_coefficient_modes() | ||
|
||
# This ETP will trigger the uniform1dx4 kernel | ||
yield ( | ||
descriptors.channelwise_tensor_product( | ||
cue.descriptors.channelwise_tensor_product( | ||
cue.Irreps("O3", "32x0e + 32x1o"), | ||
cue.Irreps("O3", "0e + 1o + 2e"), | ||
cue.Irreps("O3", "0e + 1o"), | ||
|
@@ -48,8 +45,8 @@ def make_descriptors(): | |
) | ||
|
||
# These ETPs will trigger the symmetricContraction kernel | ||
yield descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5]) | ||
yield descriptors.symmetric_contraction( | ||
yield cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) | ||
yield cue.descriptors.symmetric_contraction( | ||
cue.Irreps("O3", "32x0e + 32x1o"), | ||
cue.Irreps("O3", "32x0e + 32x1o"), | ||
[0, 1, 2, 3], | ||
|
@@ -184,3 +181,25 @@ def test_export( | |
m = module_with_mode(mode, m, exp_inputs, math_dtype, tmp_path) | ||
res_script = m(*inputs) | ||
torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) | ||
|
||
|
||
@pytest.mark.parametrize("batch_size", [0, 5]) | ||
@pytest.mark.parametrize("use_fallback", [True, False]) | ||
def test_high_degrees(use_fallback: bool, batch_size: int): | ||
if not use_fallback and not torch.cuda.is_available(): | ||
pytest.skip("CUDA is not available") | ||
|
||
e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5]) | ||
m = cuet.EquivariantTensorProduct( | ||
e, | ||
layout=cue.mul_ir, | ||
device=device, | ||
math_dtype=torch.float32, | ||
use_fallback=use_fallback, | ||
) | ||
inputs = [ | ||
torch.randn((batch_size, rep.dim), device=device, dtype=torch.float32) | ||
for rep in e.inputs | ||
] | ||
output = m(*inputs) | ||
assert output.shape == (batch_size, e.output.dim) | ||
Comment on lines
+186
to
+205
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I separately test if a high degree polynomial goes through. |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I put back a lower degree polynomial.
The degree 5 was giving a bit too high errors compared to all the other ETP tested.