Fix tests #62
Conversation
@@ -48,8 +45,8 @@ def make_descriptors():
    )

    # These ETPs will trigger the symmetricContraction kernel
    yield descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5])
Here I put back a lower-degree polynomial. The degree-5 descriptor was giving noticeably higher errors than all the other ETPs tested.
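For intuition, here is a small, self-contained sketch (not part of the PR) of why higher polynomial degrees tend to show larger numerical errors: a degree-l term multiplies roughly l factors of the input, so the relative float32 error grows with the degree.

```python
import torch

# Standalone illustration (not the library's test): compare float32
# monomials of degree 2 and degree 5 against a float64 reference.
# The relative error roughly scales with the degree.
x64 = torch.randn(10_000, dtype=torch.float64)
x32 = x64.to(torch.float32)
for degree in (2, 5):
    ref = x64**degree
    err = ((x32**degree).double() - ref).abs() / ref.abs().clamp_min(1e-30)
    print(degree, err.max().item())
```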
@pytest.mark.parametrize("batch_size", [0, 5]) | ||
@pytest.mark.parametrize("use_fallback", [True, False]) | ||
def test_high_degrees(use_fallback: bool, batch_size: int): | ||
if not use_fallback and not torch.cuda.is_available(): | ||
pytest.skip("CUDA is not available") | ||
|
||
e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5]) | ||
m = cuet.EquivariantTensorProduct( | ||
e, | ||
layout=cue.mul_ir, | ||
device=device, | ||
math_dtype=torch.float32, | ||
use_fallback=use_fallback, | ||
) | ||
inputs = [ | ||
torch.randn((batch_size, rep.dim), device=device, dtype=torch.float32) | ||
for rep in e.inputs | ||
] | ||
output = m(*inputs) | ||
assert output.shape == (batch_size, e.output.dim) |
Here I separately test that a high-degree polynomial goes through.
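As a rough usage sketch (assuming the imports used elsewhere in the tests, i.e. `import cuequivariance as cue` and `import cuequivariance_torch as cuet`): the descriptor built from degrees [0, 1, 2, 3, 4, 5] over SO3 has output dimension sum(2*l + 1 for l in range(6)) == 36, which is what the shape assertion checks.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet  # assumed import name, as in the test module

# Sketch: run the degree-[0..5] spherical-harmonics ETP on CPU via the
# fallback path and check the expected output dimension (1+3+5+7+9+11 = 36).
e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3, 4, 5])
m = cuet.EquivariantTensorProduct(
    e,
    layout=cue.mul_ir,
    device="cpu",
    math_dtype=torch.float32,
    use_fallback=True,
)
inputs = [torch.randn(5, rep.dim, dtype=torch.float32) for rep in e.inputs]
output = m(*inputs)
print(output.shape)  # expected: torch.Size([5, 36]), i.e. (batch_size, e.output.dim)
```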