Compatibility with jit.script and torch.compile: COMPLETE #40

Merged · 55 commits · Dec 9, 2024

Commits
ba9580a
test and quick fix for zero batch
mariogeiger Nov 20, 2024
0bfada9
trigger uniform 1d in test
mariogeiger Nov 20, 2024
fd097c6
satisfy linter
mariogeiger Nov 21, 2024
251fc4d
from typing import
mariogeiger Nov 21, 2024
3498a32
determine math_dtype earlier
mariogeiger Nov 21, 2024
7f3cf05
warning with pip commands
mariogeiger Nov 21, 2024
2624335
remove unused argument
mariogeiger Nov 21, 2024
91f7fce
changelog
mariogeiger Nov 21, 2024
4401048
list of inputs
mariogeiger Nov 21, 2024
ad2db8d
add Fixed subtitle
mariogeiger Nov 21, 2024
dca96a8
Merge branch 'zero-batch' into list-inputs
mariogeiger Nov 21, 2024
889051a
changelog
mariogeiger Nov 21, 2024
c23816a
Merge branch 'main' into list-inputs
mariogeiger Nov 21, 2024
0487d77
Merge branch 'main' into list-inputs
mariogeiger Dec 3, 2024
bc6b405
add test for torch.jit.script
mariogeiger Dec 3, 2024
c8de185
fix
mariogeiger Dec 3, 2024
5e00b37
Merge branch 'list-inputs' into jit-script
mariogeiger Dec 3, 2024
16e4450
remove keyword-only and import in the forward
mariogeiger Dec 3, 2024
e979b0f
Merge branch 'main' into jit-script
mariogeiger Dec 4, 2024
b2c4fbb
low lvl script tests
mariogeiger Dec 4, 2024
4669a86
TensorProduct working with script()
borisfom Dec 4, 2024
dc9d5b0
add 4 operands tests
mariogeiger Dec 4, 2024
334b460
Unit tests run
borisfom Dec 5, 2024
79e7c5f
Restoring debug logging
borisfom Dec 5, 2024
46a0478
Merge branch 'jit-script' into jit-script
borisfom Dec 5, 2024
401fd53
Merge branch 'jit-script' of github.com:NVIDIA/cuEquivariance into ji…
borisfom Dec 5, 2024
8fce54b
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 5, 2024
6c5cdb0
Parameterized script test
borisfom Dec 5, 2024
e21c45f
Fixed transpose for script(), script_test successful
borisfom Dec 5, 2024
779dd9c
Fixed input mutation
borisfom Dec 5, 2024
c315857
Fixed tests
borisfom Dec 6, 2024
ab590c8
format with black
mariogeiger Dec 6, 2024
ec1eb27
format with black
mariogeiger Dec 6, 2024
faf235e
fix tests
mariogeiger Dec 6, 2024
c476af9
fix missing parenthesis
mariogeiger Dec 6, 2024
994b8d9
fix tests: increase torch._dynamo.config.cache_size_limit
mariogeiger Dec 6, 2024
f240eb8
fix docstring tests
mariogeiger Dec 6, 2024
fbfb9d0
replace == by is
mariogeiger Dec 6, 2024
dc20be5
clean use_fallback conditions
mariogeiger Dec 6, 2024
4b201c3
fix
mariogeiger Dec 6, 2024
b5b59b8
fix
mariogeiger Dec 6, 2024
72baf17
Export test added, scripting fallback attempt
borisfom Dec 7, 2024
5a94b09
Merge remote-tracking branch 'b/jit-script' into jit-script
borisfom Dec 7, 2024
6bdf924
Merge branch 'main' into jit-script
mariogeiger Dec 9, 2024
8d31929
enable tests on cpu
mariogeiger Dec 9, 2024
8afa056
fix tests
mariogeiger Dec 9, 2024
09bbc8d
fix ruff
mariogeiger Dec 9, 2024
9c38168
fix
mariogeiger Dec 9, 2024
de9af8f
fix docstring tests
mariogeiger Dec 9, 2024
999a31d
add -x to tests
mariogeiger Dec 9, 2024
905e716
changelog
mariogeiger Dec 9, 2024
975e9c8
test
mariogeiger Dec 9, 2024
093e8e4
move utils into test file
mariogeiger Dec 9, 2024
2712f54
fix
mariogeiger Dec 9, 2024
008ee3d
Merge branch 'main' into jit-script
mariogeiger Dec 9, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
## Latest Changes

### Added

- Partial support of `torch.jit.script` and `torch.compile`

### Changed

- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.
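To make these two entries concrete, here is a minimal sketch of the new calling convention together with `torch.jit.script` and `torch.compile`. It is illustrative only: the descriptor helper `cue.descriptors.spherical_harmonics` and the `use_fallback=True` CPU path are assumptions about the public API, not code taken from this PR.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# Hypothetical descriptor with a single input operand (a 3D vector).
e = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2])

# use_fallback=True forces the pure-PyTorch path, so the sketch runs without a GPU.
m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)

vec = torch.randn(16, 3)
y = m([vec])  # inputs are now passed as a list of tensors

scripted = torch.jit.script(m)  # "partial support": simple modules like this one
compiled = torch.compile(m)     # torch.compile is also partially supported
torch.testing.assert_close(y, scripted([vec]))
```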
@@ -56,6 +56,10 @@ class FullyConnectedTensorProductConv(nn.Module):
mlp_channels (Sequence of int, optional): A sequence of integers defining the number of neurons in each layer in MLP before the output layer. If None, no MLP will be added. The input layer contains edge embeddings and node scalar features. Defaults to None.
mlp_activation (``nn.Module`` or Sequence of ``nn.Module``, optional): A sequence of functions to be applied in between linear layers in MLP, e.g., ``nn.Sequential(nn.ReLU(), nn.Dropout(0.4))``. Defaults to ``nn.GELU()``.
layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

Examples:
>>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o")
@@ -66,29 +70,9 @@ class FullyConnectedTensorProductConv(nn.Module):
having 16 channels. edge_emb.size(1) must match the size of the input layer: 6

>>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul).cuda()
... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul)
>>> conv1
FullyConnectedTensorProductConv(
(tp): FullyConnectedTensorProduct(
shared_weights=False, internal_weights=False, weight_numel=64
(f): EquivariantTensorProduct(
EquivariantTensorProduct(64x0e x 4x0e+4x1o x 0e+1o -> 4x0e+4x1o)
(transpose_in): ModuleList(
(0-2): 3 x TransposeIrrepsLayout((irrep,mul) -> (irrep,mul))
)
(transpose_out): TransposeIrrepsLayout((irrep,mul) -> (irrep,mul))
(tp): TensorProduct(uvw,iu,jv,kw+ijk sizes=64,16,4,16 num_segments=4,2,2,2 num_paths=4 i={1, 3} j={1, 3} k={1, 3} u=4 v=1 w=4 (with CUDA kernel))
)
)
(batch_norm): BatchNorm(4x0e+4x1o, layout=(irrep,mul), eps=1e-05, momentum=0.1)
(mlp): Sequential(
(0): Linear(in_features=6, out_features=16, bias=True)
(1): ReLU()
(2): Linear(in_features=16, out_features=16, bias=True)
(3): ReLU()
(4): Linear(in_features=16, out_features=64, bias=True)
)
)
FullyConnectedTensorProductConv(...)
>>> # out = conv1(src_features, edge_sh, edge_emb, graph)

**Case 2**: If edge_emb is constructed by concatenating scalar features from
@@ -108,7 +92,7 @@ class FullyConnectedTensorProductConv(nn.Module):
**Case 3**: No MLP, edge_emb will be directly used as the tensor product weights:

>>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps,
... mlp_channels=None, layout=cue.ir_mul).cuda()
... mlp_channels=None, layout=cue.ir_mul)
>>> # out = conv3(src_features, edge_sh, edge_emb, graph)
"""

@@ -121,6 +105,7 @@ def __init__(
mlp_channels: Optional[Sequence[int]] = None,
mlp_activation: Union[nn.Module, Sequence[nn.Module], None] = nn.GELU(),
layout: cue.IrrepsLayout = None, # e3nn_compat_mode
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
@@ -141,6 +126,7 @@ def __init__(
out_irreps,
layout=self.layout,
shared_weights=False,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

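As a short, hypothetical follow-up to the docstring examples above (reusing `in_irreps`, `sh_irreps`, and `out_irreps`, and assuming the same imports are in scope), the new constructor-level `use_fallback` argument is used like this:

```python
# use_fallback=True selects the PyTorch fallback, so the module also works on CPU;
# use_fallback=False would require the CUDA kernel and raise if it is unavailable.
conv_cpu = FullyConnectedTensorProductConv(
    in_irreps, sh_irreps, out_irreps,
    mlp_channels=[6, 16, 16],
    mlp_activation=nn.ReLU(),
    layout=cue.ir_mul,
    use_fallback=True,
)
```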
13 changes: 7 additions & 6 deletions cuequivariance_torch/cuequivariance_torch/operations/linear.py
@@ -32,6 +32,10 @@ class Linear(torch.nn.Module):
layout (IrrepsLayout, optional): The layout of the irreducible representations, by default ``cue.mul_ir``. This is the layout used in the e3nn library.
shared_weights (bool, optional): Whether to use shared weights, by default True.
internal_weights (bool, optional): Whether to use internal weights, by default True if shared_weights is True, otherwise False.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.
"""

def __init__(
@@ -47,6 +51,7 @@ def __init__(
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
math_dtype: Optional[torch.dtype] = None,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
@@ -84,6 +89,7 @@ def __init__(
layout_out=layout_out,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

@@ -94,18 +100,13 @@ def forward(
self,
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
*,
use_fallback: Optional[bool] = None,
) -> torch.Tensor:
"""
Forward pass of the linear layer.

Args:
x (torch.Tensor): The input tensor.
weight (torch.Tensor, optional): The weight tensor. If None, the internal weight tensor is used.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

Returns:
torch.Tensor: The output tensor after applying the linear transformation.
@@ -126,4 +127,4 @@ def forward(
if not self.shared_weights and weight.ndim != 2:
raise ValueError("Weights should be 2D tensor")

return self.f([weight, x], use_fallback=use_fallback)
return self.f([weight, x])
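A hedged before/after sketch of the `Linear` change: `use_fallback` now lives in the constructor and the forward pass only takes the input (and an optional weight). The irreps below are made up for illustration.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "8x0e + 8x1o")

# use_fallback=True keeps the example runnable on CPU.
linear = cuet.Linear(irreps, irreps, layout=cue.mul_ir, use_fallback=True)

x = torch.randn(32, irreps.dim)
# Before this PR: y = linear(x, use_fallback=True)
y = linear(x)  # the fallback choice is now fixed at construction time

y_scripted = torch.jit.script(linear)(x)  # partial TorchScript support
```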
17 changes: 11 additions & 6 deletions cuequivariance_torch/cuequivariance_torch/operations/rotation.py
@@ -40,6 +40,7 @@ def __init__(
layout_out: Optional[cue.IrrepsLayout] = None,
device: Optional[torch.device] = None,
math_dtype: Optional[torch.dtype] = None,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
@@ -60,6 +61,7 @@
layout_out=layout_out,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

@@ -69,8 +71,6 @@ def forward(
beta: torch.Tensor,
alpha: torch.Tensor,
x: torch.Tensor,
*,
use_fallback: Optional[bool] = None,
) -> torch.Tensor:
"""
Forward pass of the rotation layer.
@@ -80,9 +80,6 @@ def forward(
beta (torch.Tensor): The beta angles. Second rotation around the x-axis.
alpha (torch.Tensor): The alpha angles. Third rotation around the y-axis.
x (torch.Tensor): The input tensor.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

Returns:
torch.Tensor: The rotated tensor.
@@ -97,7 +94,6 @@

return self.f(
[encodings_gamma, encodings_beta, encodings_alpha, x],
use_fallback=use_fallback,
)


@@ -159,6 +155,11 @@ class Inversion(torch.nn.Module):
Args:
irreps (Irreps): The irreducible representations of the tensor to invert.
layout (IrrepsLayout, optional): The memory layout of the tensor, ``cue.ir_mul`` is preferred.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.
"""

def __init__(
Expand All @@ -170,6 +171,8 @@ def __init__(
layout_out: Optional[cue.IrrepsLayout] = None,
device: Optional[torch.device] = None,
math_dtype: Optional[torch.dtype] = None,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
(irreps,) = default_irreps(irreps)
@@ -187,6 +190,8 @@ def __init__(
layout_out=layout_out,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
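The same pattern applies to `Rotation` and `Inversion`: `use_fallback` moves to the constructor and the keyword-only argument disappears from `forward`. A hypothetical sketch follows; the positional `irreps` argument and the angle shapes are assumptions, since the first lines of `__init__` sit outside this hunk.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("SO3", "2x1")  # assumed constructor argument
rot = cuet.Rotation(irreps, layout=cue.ir_mul, use_fallback=True)

x = torch.randn(10, irreps.dim)
gamma, beta, alpha = torch.randn(10), torch.randn(10), torch.randn(10)

# Before this PR: rot(gamma, beta, alpha, x, use_fallback=True)
y = rot(gamma, beta, alpha, x)
```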
@@ -25,6 +25,7 @@ def spherical_harmonics(
ls: list[int],
vectors: torch.Tensor,
normalize: bool = True,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
) -> torch.Tensor:
r"""Compute the spherical harmonics of the input vectors.
@@ -33,6 +34,10 @@
ls (list of int): List of spherical harmonic degrees.
vectors (torch.Tensor): Input vectors of shape (..., 3).
normalize (bool, optional): Whether to normalize the input vectors. Defaults to True.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

Returns:
@@ -53,6 +58,7 @@
layout=cue.ir_mul,
device=x.device,
math_dtype=x.dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)
y = m([x])
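For reference, a small usage sketch of the functional form whose new signature appears above; passing `use_fallback=True` (an assumption made here so no CUDA kernel is needed) selects the PyTorch implementation:

```python
import torch
import cuequivariance_torch as cuet

vectors = torch.randn(64, 3)
# Degrees 0, 1, 2 -> output dimension 1 + 3 + 5 = 9.
sh = cuet.spherical_harmonics([0, 1, 2], vectors, normalize=True, use_fallback=True)
print(sh.shape)  # torch.Size([64, 9])
```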
@@ -38,6 +38,10 @@ class SymmetricContraction(torch.nn.Module):
layout (IrrepsLayout, optional): The layout of the input and output irreps. If not provided, a default layout is used.
math_dtype (torch.dtype, optional): The data type for mathematical operations. If not specified, the default data type
from the torch environment is used.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

Examples:
>>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o")
@@ -48,6 +52,7 @@

The argument `original_mace` can be set to `True` to emulate the original MACE implementation.

>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e")
>>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o")
>>> # OLD FUNCTION DEFINITION:
@@ -67,14 +72,15 @@ class SymmetricContraction(torch.nn.Module):
... layout_out=cue.mul_ir,
... original_mace=True,
... dtype=torch.float64,
... device=device,
... )

Then the execution is as follows:

>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64)
>>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device)
>>> # with node_attrs_index being the index version of node_attrs, sth like:
>>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int()
>>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32)
>>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device)
>>> # OLD CALL:
>>> # symmetric_contractions_old(node_feats, node_attrs)
>>> # NEW CALL:
@@ -102,6 +108,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
math_dtype: Optional[torch.dtype] = None,
original_mace: bool = False,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
@@ -147,6 +154,7 @@ def __init__(
layout_out=layout_out,
device=device,
math_dtype=math_dtype or dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

@@ -160,8 +168,6 @@ def forward(
self,
x: torch.Tensor,
indices: torch.Tensor,
*,
use_fallback: Optional[bool] = None,
) -> torch.Tensor:
"""
Perform the forward pass of the symmetric contraction operation.
@@ -170,9 +176,6 @@
x (torch.Tensor): The input tensor. It should have shape (..., irreps_in.dim).
indices (torch.Tensor): The index of the weight to use for each batch element.
It should have shape (...).
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

Returns:
torch.Tensor: The output tensor. It has shape (batch, irreps_out.dim).
@@ -188,4 +191,4 @@
weight = self.weight
weight = weight.flatten(1)

return self.f([weight, x], indices=indices, use_fallback=use_fallback)
return self.f([weight, x], indices=indices)
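Since `use_fallback` is gone from `forward`, a hypothetical continuation of the docstring example above (same `symmetric_contractions`, `node_feats`, and `node_attrs_index` names; whether this particular module is covered by the partial `torch.compile` support is not verified here) looks like:

```python
# NEW CALL (use_fallback is chosen in the constructor, not here):
out = symmetric_contractions(node_feats, node_attrs_index)

# In line with the goal of this PR, the module can also be wrapped:
compiled = torch.compile(symmetric_contractions)
out_compiled = compiled(node_feats, node_attrs_index)
```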
@@ -33,6 +33,10 @@ class ChannelWiseTensorProduct(torch.nn.Module):
layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn.
shared_weights (bool, optional): Whether to share weights across the batch dimension. Default is True.
internal_weights (bool, optional): Whether to create module parameters for weights. Default is None.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None.

Note:
In e3nn there was a irrep_normalization and path_normalization parameters.
@@ -54,6 +58,7 @@ def __init__(
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
math_dtype: Optional[torch.dtype] = None,
use_fallback: Optional[bool] = None,
optimize_fallback: Optional[bool] = None,
):
super().__init__()
@@ -95,6 +100,7 @@ def __init__(
layout_out=layout_out,
device=device,
math_dtype=math_dtype,
use_fallback=use_fallback,
optimize_fallback=optimize_fallback,
)

@@ -110,8 +116,6 @@ def forward(
x1: torch.Tensor,
x2: torch.Tensor,
weight: Optional[torch.Tensor] = None,
*,
use_fallback: Optional[bool] = None,
) -> torch.Tensor:
"""
Perform the forward pass of the fully connected tensor product operation.
@@ -122,9 +126,6 @@
weight (torch.Tensor, optional): Weights for the tensor product. It should have the shape (batch_size, weight_numel)
if shared_weights is False, or (weight_numel,) if shared_weights is True.
If None, the internal weights are used.
use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.

Returns:
torch.Tensor:
@@ -147,4 +148,4 @@ def forward(
if not self.shared_weights and weight.ndim != 2:
raise ValueError("Weights should be 2D tensor")

return self.f([weight, x1, x2], use_fallback=use_fallback)
return self.f([weight, x1, x2])
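Finally, a hedged sketch for `ChannelWiseTensorProduct` with the constructor-level `use_fallback`; the positional `filter_irreps_out` argument and the default internal weights are assumptions, since that part of the signature is not visible in this diff.

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps1 = cue.Irreps("O3", "8x0e + 8x1o")
irreps2 = cue.Irreps("O3", "0e + 1o")

tp = cuet.ChannelWiseTensorProduct(
    irreps1, irreps2, cue.Irreps("O3", "0e + 1o"),  # assumed filter_irreps_out
    layout=cue.ir_mul,
    use_fallback=True,  # PyTorch fallback, runs on CPU
)

x1 = torch.randn(32, irreps1.dim)
x2 = torch.randn(32, irreps2.dim)
# Before this PR: tp(x1, x2, use_fallback=True)
y = tp(x1, x2)  # internal weights are used when no weight tensor is passed
```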