Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with jit.script and torch.compile: COMPLETE #40

Merged
merged 55 commits into from
Dec 9, 2024
Merged
Changes from 1 commit
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
ba9580a
test and quick fix for zero batch
mariogeiger Nov 20, 2024
0bfada9
trigger uniform 1d in test
mariogeiger Nov 20, 2024
fd097c6
satisfy linter
mariogeiger Nov 21, 2024
251fc4d
from typing import
mariogeiger Nov 21, 2024
3498a32
determine math_dtype earlier
mariogeiger Nov 21, 2024
7f3cf05
warning with pip commands
mariogeiger Nov 21, 2024
2624335
remove unused argument
mariogeiger Nov 21, 2024
91f7fce
changelog
mariogeiger Nov 21, 2024
4401048
list of inputs
mariogeiger Nov 21, 2024
ad2db8d
add Fixed subtite
mariogeiger Nov 21, 2024
dca96a8
Merge branch 'zero-batch' into list-inputs
mariogeiger Nov 21, 2024
889051a
changelog
mariogeiger Nov 21, 2024
c23816a
Merge branch 'main' into list-inputs
mariogeiger Nov 21, 2024
0487d77
Merge branch 'main' into list-inputs
mariogeiger Dec 3, 2024
bc6b405
add test for torch.jit.script
mariogeiger Dec 3, 2024
c8de185
fix
mariogeiger Dec 3, 2024
5e00b37
Merge branch 'list-inputs' into jit-script
mariogeiger Dec 3, 2024
16e4450
remove keyword-only and import in the forward
mariogeiger Dec 3, 2024
e979b0f
Merge branch 'main' into jit-script
mariogeiger Dec 4, 2024
b2c4fbb
low lvl script tests
mariogeiger Dec 4, 2024
4669a86
TensorProduct working with script()
borisfom Dec 4, 2024
dc9d5b0
add 4 operands tests
mariogeiger Dec 4, 2024
334b460
Unit tests run
borisfom Dec 5, 2024
79e7c5f
Restoring debug logging
borisfom Dec 5, 2024
46a0478
Merge branch 'jit-script' into jit-script
borisfom Dec 5, 2024
401fd53
Merge branch 'jit-script' of github.com:NVIDIA/cuEquivariance into ji…
borisfom Dec 5, 2024
8fce54b
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 5, 2024
6c5cdb0
Parameterized script test
borisfom Dec 5, 2024
e21c45f
Fixed transpose for script(), script_test successful
borisfom Dec 5, 2024
779dd9c
Fixed input mutation
borisfom Dec 5, 2024
c315857
Fixed tests
borisfom Dec 6, 2024
ab590c8
format with black
mariogeiger Dec 6, 2024
ec1eb27
format with black
mariogeiger Dec 6, 2024
faf235e
fix tests
mariogeiger Dec 6, 2024
c476af9
fix missing parenthesis
mariogeiger Dec 6, 2024
994b8d9
fix tests: increase torch._dynamo.config.cache_size_limit
mariogeiger Dec 6, 2024
f240eb8
fix docstring tests
mariogeiger Dec 6, 2024
fbfb9d0
replace == by is
mariogeiger Dec 6, 2024
dc20be5
clean use_fallback conditions
mariogeiger Dec 6, 2024
4b201c3
fix
mariogeiger Dec 6, 2024
b5b59b8
fix
mariogeiger Dec 6, 2024
72baf17
Export test added, scripting fallback attempt
borisfom Dec 7, 2024
5a94b09
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 7, 2024
6bdf924
Merge branch 'main' into jit-script
mariogeiger Dec 9, 2024
8d31929
enable tests on cpu
mariogeiger Dec 9, 2024
8afa056
fix tests
mariogeiger Dec 9, 2024
09bbc8d
fix ruff
mariogeiger Dec 9, 2024
9c38168
fix
mariogeiger Dec 9, 2024
de9af8f
fix docstring tests
mariogeiger Dec 9, 2024
999a31d
add -x to tests
mariogeiger Dec 9, 2024
905e716
changelog
mariogeiger Dec 9, 2024
975e9c8
test
mariogeiger Dec 9, 2024
093e8e4
move utils into test file
mariogeiger Dec 9, 2024
2712f54
fix
mariogeiger Dec 9, 2024
008ee3d
Merge branch 'main' into jit-script
mariogeiger Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
test and quick fix for zero batch
  • Loading branch information
mariogeiger committed Nov 20, 2024
commit ba9580a62ceea0d4310b4ffbef210aaf6e7e9a0b
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import *
from typing import Optional, Union

import torch

@@ -70,9 +70,6 @@ def __init__(
optimize_fallback: Optional[bool] = None,
):
super().__init__()
cue.descriptors.fully_connected_tensor_product(
cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
)
if not isinstance(layout_in, tuple):
layout_in = (layout_in,) * e.num_inputs
if len(layout_in) != e.num_inputs:
Original file line number Diff line number Diff line change
@@ -14,15 +14,13 @@
# limitations under the License.
import logging
import math
import warnings
from typing import *
from typing import Optional

import torch
import torch.fx

import cuequivariance.segmented_tensor_product as stp
import cuequivariance_torch as cuet
from cuequivariance import segmented_tensor_product as stp

logger = logging.getLogger(__name__)

@@ -341,7 +339,7 @@ def forward(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
out = out.reshape(out.shape[0], -1)
out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out


Original file line number Diff line number Diff line change
@@ -88,6 +88,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
Raises:
RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
"""
if any(x.numel() == 0 for x in args):
use_fallback = True # Empty tensors are not supported by the CUDA kernel

if (
args
and args[0].device.type == "cuda"
@@ -285,7 +288,7 @@ def forward(self, *args):
(math.prod(shape), arg.shape[-1])
)
if math.prod(arg.shape[:-1]) > 1
else arg.reshape((1, arg.shape[-1]))
else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1]))
)
for arg in args
]
Original file line number Diff line number Diff line change
@@ -30,7 +30,8 @@
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("original_mace", [True, False])
def test_symmetric_contraction(dtype, layout, original_mace):
@pytest.mark.parametrize("batch", [0, 32])
def test_symmetric_contraction(dtype, layout, original_mace, batch):
mul = 64
irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e")
irreps_out = mul * cue.Irreps("O3", "0e + 1o")
@@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace):
original_mace=original_mace,
)

Z = 32
x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda()
indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda()
x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda()
indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda()

out = m(x, indices)
assert out.shape == (Z, irreps_out.dim)
assert out.shape == (batch, irreps_out.dim)


def from64(shape: tuple[int, ...], data: str) -> torch.Tensor:
6 changes: 4 additions & 2 deletions cuequivariance_torch/tests/operations/tp_channel_wise_test.py
Original file line number Diff line number Diff line change
@@ -31,12 +31,14 @@
@pytest.mark.parametrize("irreps3", list_of_irreps)
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("use_fallback", [False, True])
@pytest.mark.parametrize("batch", [0, 32])
def test_channel_wise(
irreps1: cue.Irreps,
irreps2: cue.Irreps,
irreps3: cue.Irreps,
layout: cue.IrrepsLayout,
use_fallback: bool,
batch: int,
):
m = cuet.ChannelWiseTensorProduct(
irreps1,
@@ -49,8 +51,8 @@ def test_channel_wise(
dtype=torch.float64,
)

x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda()
x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda()
x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda()
x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda()

out1 = m(x1, x2, use_fallback=use_fallback)

2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Change Log
# Changelog

```{include} ../CHANGELOG.md