diff --git a/CHANGELOG.md b/CHANGELOG.md index 867796d..3ab694f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Latest Changes +### Added + +- Partial support of `torch.jit.script` and `torch.compile` + ### Changed - `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input. diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index 18df53f..b6a6fb0 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -56,6 +56,10 @@ class FullyConnectedTensorProductConv(nn.Module): mlp_channels (Sequence of int, optional): A sequence of integers defining the number of neurons in each layer in MLP before the output layer. If None, no MLP will be added. The input layer contains edge embeddings and node scalar features. Defaults to None. mlp_activation (``nn.Module`` or Sequence of ``nn.Module``, optional): A sequence of functions to be applied in between linear layers in MLP, e.g., ``nn.Sequential(nn.ReLU(), nn.Dropout(0.4))``. Defaults to ``nn.GELU()``. layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: >>> in_irreps = cue.Irreps("O3", "4x0e + 4x1o") @@ -66,29 +70,9 @@ class FullyConnectedTensorProductConv(nn.Module): having 16 channels. edge_emb.size(1) must match the size of the input layer: 6 >>> conv1 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, - ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul).cuda() + ... mlp_channels=[6, 16, 16], mlp_activation=nn.ReLU(), layout=cue.ir_mul) >>> conv1 - FullyConnectedTensorProductConv( - (tp): FullyConnectedTensorProduct( - shared_weights=False, internal_weights=False, weight_numel=64 - (f): EquivariantTensorProduct( - EquivariantTensorProduct(64x0e x 4x0e+4x1o x 0e+1o -> 4x0e+4x1o) - (transpose_in): ModuleList( - (0-2): 3 x TransposeIrrepsLayout((irrep,mul) -> (irrep,mul)) - ) - (transpose_out): TransposeIrrepsLayout((irrep,mul) -> (irrep,mul)) - (tp): TensorProduct(uvw,iu,jv,kw+ijk sizes=64,16,4,16 num_segments=4,2,2,2 num_paths=4 i={1, 3} j={1, 3} k={1, 3} u=4 v=1 w=4 (with CUDA kernel)) - ) - ) - (batch_norm): BatchNorm(4x0e+4x1o, layout=(irrep,mul), eps=1e-05, momentum=0.1) - (mlp): Sequential( - (0): Linear(in_features=6, out_features=16, bias=True) - (1): ReLU() - (2): Linear(in_features=16, out_features=16, bias=True) - (3): ReLU() - (4): Linear(in_features=16, out_features=64, bias=True) - ) - ) + FullyConnectedTensorProductConv(...) >>> # out = conv1(src_features, edge_sh, edge_emb, graph) **Case 2**: If edge_emb is constructed by concatenating scalar features from @@ -108,7 +92,7 @@ class FullyConnectedTensorProductConv(nn.Module): **Case 3**: No MLP, edge_emb will be directly used as the tensor product weights: >>> conv3 = FullyConnectedTensorProductConv(in_irreps, sh_irreps, out_irreps, - ... mlp_channels=None, layout=cue.ir_mul).cuda() + ... mlp_channels=None, layout=cue.ir_mul) >>> # out = conv3(src_features, edge_sh, edge_emb, graph) """ @@ -121,6 +105,7 @@ def __init__( mlp_channels: Optional[Sequence[int]] = None, mlp_activation: Union[nn.Module, Sequence[nn.Module], None] = nn.GELU(), layout: cue.IrrepsLayout = None, # e3nn_compat_mode + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -141,6 +126,7 @@ def __init__( out_irreps, layout=self.layout, shared_weights=False, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index a9f16c1..41c7e89 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -32,6 +32,10 @@ class Linear(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the irreducible representations, by default ``cue.mul_ir``. This is the layout used in the e3nn library. shared_weights (bool, optional): Whether to use shared weights, by default True. internal_weights (bool, optional): Whether to use internal weights, by default True if shared_weights is True, otherwise False. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -47,6 +51,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -84,6 +89,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -94,8 +100,6 @@ def forward( self, x: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Forward pass of the linear layer. @@ -103,9 +107,6 @@ def forward( Args: x (torch.Tensor): The input tensor. weight (torch.Tensor, optional): The weight tensor. If None, the internal weight tensor is used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The output tensor after applying the linear transformation. @@ -126,4 +127,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x], use_fallback=use_fallback) + return self.f([weight, x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index 435177f..cc2356f 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -40,6 +40,7 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -60,6 +61,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -69,8 +71,6 @@ def forward( beta: torch.Tensor, alpha: torch.Tensor, x: torch.Tensor, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Forward pass of the rotation layer. @@ -80,9 +80,6 @@ def forward( beta (torch.Tensor): The beta angles. Second rotation around the x-axis. alpha (torch.Tensor): The alpha angles. Third rotation around the y-axis. x (torch.Tensor): The input tensor. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The rotated tensor. @@ -97,7 +94,6 @@ def forward( return self.f( [encodings_gamma, encodings_beta, encodings_alpha, x], - use_fallback=use_fallback, ) @@ -159,6 +155,11 @@ class Inversion(torch.nn.Module): Args: irreps (Irreps): The irreducible representations of the tensor to invert. layout (IrrepsLayout, optional): The memory layout of the tensor, ``cue.ir_mul`` is preferred. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. """ def __init__( @@ -170,6 +171,8 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, + optimize_fallback: Optional[bool] = None, ): super().__init__() (irreps,) = default_irreps(irreps) @@ -187,6 +190,8 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py index 8e324a6..da6be5d 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/spherical_harmonics.py @@ -25,6 +25,7 @@ def spherical_harmonics( ls: list[int], vectors: torch.Tensor, normalize: bool = True, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ) -> torch.Tensor: r"""Compute the spherical harmonics of the input vectors. @@ -33,6 +34,10 @@ def spherical_harmonics( ls (list of int): List of spherical harmonic degrees. vectors (torch.Tensor): Input vectors of shape (..., 3). normalize (bool, optional): Whether to normalize the input vectors. Defaults to True. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Returns: @@ -53,6 +58,7 @@ def spherical_harmonics( layout=cue.ir_mul, device=x.device, math_dtype=x.dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) y = m([x]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 35f1b56..a91a72a 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -38,6 +38,10 @@ class SymmetricContraction(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. If not provided, a default layout is used. math_dtype (torch.dtype, optional): The data type for mathematical operations. If not specified, the default data type from the torch environment is used. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Examples: >>> irreps_in = cue.Irreps("O3", "32x0e + 32x1o") @@ -48,6 +52,7 @@ class SymmetricContraction(torch.nn.Module): The argument `original_mace` can be set to `True` to emulate the original MACE implementation. + >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> feats_irreps = cue.Irreps("O3", "32x0e + 32x1o + 32x2e") >>> target_irreps = cue.Irreps("O3", "32x0e + 32x1o") >>> # OLD FUNCTION DEFINITION: @@ -67,14 +72,15 @@ class SymmetricContraction(torch.nn.Module): ... layout_out=cue.mul_ir, ... original_mace=True, ... dtype=torch.float64, + ... device=device, ... ) Then the execution is as follows: - >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64) + >>> node_feats = torch.randn(128, 32, feats_irreps.dim // 32, dtype=torch.float64, device=device) >>> # with node_attrs_index being the index version of node_attrs, sth like: >>> # node_attrs_index = torch.nonzero(node_attrs)[:, 1].int() - >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32) + >>> node_attrs_index = torch.randint(0, 10, (128,), dtype=torch.int32, device=device) >>> # OLD CALL: >>> # symmetric_contractions_old(node_feats, node_attrs) >>> # NEW CALL: @@ -102,6 +108,7 @@ def __init__( dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, original_mace: bool = False, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -147,6 +154,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype or dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -160,8 +168,6 @@ def forward( self, x: torch.Tensor, indices: torch.Tensor, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the symmetric contraction operation. @@ -170,9 +176,6 @@ def forward( x (torch.Tensor): The input tensor. It should have shape (..., irreps_in.dim). indices (torch.Tensor): The index of the weight to use for each batch element. It should have shape (...). - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: The output tensor. It has shape (batch, irreps_out.dim). @@ -188,4 +191,4 @@ def forward( weight = self.weight weight = weight.flatten(1) - return self.f([weight, x], indices=indices, use_fallback=use_fallback) + return self.f([weight, x], indices=indices) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index 5533a35..a6ac80f 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -33,6 +33,10 @@ class ChannelWiseTensorProduct(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. shared_weights (bool, optional): Whether to share weights across the batch dimension. Default is True. internal_weights (bool, optional): Whether to create module parameters for weights. Default is None. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. @@ -54,6 +58,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -95,6 +100,7 @@ def __init__( layout_out=layout_out, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) @@ -110,8 +116,6 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the fully connected tensor product operation. @@ -122,9 +126,6 @@ def forward( weight (torch.Tensor, optional): Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -147,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x1, x2], use_fallback=use_fallback) + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index cfae6be..44c781d 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -33,6 +33,10 @@ class FullyConnectedTensorProduct(torch.nn.Module): layout (IrrepsLayout, optional): The layout of the input and output irreps. Default is ``cue.mul_ir`` which is the layout corresponding to e3nn. shared_weights (bool, optional): Whether to share weights across the batch dimension. Default is True. internal_weights (bool, optional): Whether to create module parameters for weights. Default is None. + use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. + If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. + If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): Whether to optimize fallback. Defaults to None. Note: In e3nn there was a irrep_normalization and path_normalization parameters. @@ -54,6 +58,7 @@ def __init__( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -111,8 +116,6 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, weight: Optional[torch.Tensor] = None, - *, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ Perform the forward pass of the fully connected tensor product operation. @@ -123,9 +126,6 @@ def forward( weight (torch.Tensor, optional): Weights for the tensor product. It should have the shape (batch_size, weight_numel) if shared_weights is False, or (weight_numel,) if shared_weights is True. If None, the internal weights are used. - use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -148,4 +148,4 @@ def forward( if not self.shared_weights and weight.ndim != 2: raise ValueError("Weights should be 2D tensor") - return self.f([weight, x1, x2], use_fallback=use_fallback) + return self.f([weight, x1, x2]) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index e4ce239..16701ff 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -21,6 +21,93 @@ from cuequivariance.irreps_array.misc_ui import default_layout +class Dispatcher(torch.nn.Module): + def __init__(self, tp): + super().__init__() + self.tp = tp + + +class Transpose1Dispatcher(Dispatcher): + def forward(self, inputs: List[torch.Tensor]): + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + return ret + + +class Transpose2Dispatcher(Dispatcher): + def forward(self, inputs: List[torch.Tensor]): + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + return ret + + +class Transpose3Dispatcher(Dispatcher): + def forward(self, inputs: List[torch.Tensor]): + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + ret[2] = self.tp[2](ret[2]) + return ret + + +class Transpose4Dispatcher(Dispatcher): + def forward(self, inputs: List[torch.Tensor]): + ret = inputs.copy() + ret[0] = self.tp[0](ret[0]) + ret[1] = self.tp[1](ret[1]) + ret[2] = self.tp[2](ret[2]) + ret[3] = self.tp[3](ret[3]) + return ret + + +TRANSPOSE_DISPATCHERS = [ + Transpose1Dispatcher, + Transpose2Dispatcher, + Transpose3Dispatcher, + Transpose4Dispatcher, +] + + +class TPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if indices is not None: + # TODO: at some point we will have kernel for this + assert len(inputs) >= 1 + inputs[0] = inputs[0][indices] + return self.tp(inputs) + + +class SymmetricTPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert indices is None + return self.tp(inputs[0]) + + +class IWeightedSymmetricTPDispatcher(Dispatcher): + def forward( + self, + inputs: List[torch.Tensor], + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x0, x1 = inputs + if indices is None: + torch._assert( + x0.ndim == 2, + f"Expected x0 to have shape (batch, dim), got {x0.shape}", + ) + indices = torch.arange(x1.shape[0], dtype=torch.int32, device=x1.device) + return self.tp(x0, indices, x1) + + class EquivariantTensorProduct(torch.nn.Module): r"""Equivariant tensor product. @@ -31,29 +118,29 @@ class EquivariantTensorProduct(torch.nn.Module): layout_out (IrrepsLayout): layout for output. device (torch.device): device of the Module. math_dtype (torch.dtype): dtype for internal computations. + use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. optimize_fallback (bool): whether to optimize the fallback implementation. + Raises: + RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. Examples: + >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") >>> e = cue.descriptors.fully_connected_tensor_product( ... cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1") ... ) - >>> w = torch.ones(e.inputs[0].irreps.dim) - >>> x1 = torch.ones(17, e.inputs[1].irreps.dim) - >>> x2 = torch.ones(17, e.inputs[2].irreps.dim) - >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul) + >>> w = torch.ones(e.inputs[0].irreps.dim, device=device) + >>> x1 = torch.ones(17, e.inputs[1].irreps.dim, device=device) + >>> x2 = torch.ones(17, e.inputs[2].irreps.dim, device=device) + >>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) >>> tp([w, x1, x2]) - tensor([[0., 0., 0., 0., 0., 0.], - ... - [0., 0., 0., 0., 0., 0.]]) + tensor([[0., 0., 0., 0., 0., 0.],...) You can optionally index the first input tensor: - >>> w = torch.ones(3, e.inputs[0].irreps.dim) + >>> w = torch.ones(3, e.inputs[0].irreps.dim, device=device) >>> indices = torch.randint(3, (17,)) >>> tp([w, x1, x2], indices=indices) - tensor([[0., 0., 0., 0., 0., 0.], - ... - [0., 0., 0., 0., 0., 0.]]) + tensor([[0., 0., 0., 0., 0., 0.],...) """ def __init__( @@ -67,6 +154,7 @@ def __init__( layout_out: Optional[cue.IrrepsLayout] = None, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -84,52 +172,62 @@ def __init__( self.layout_in = layout_in = tuple(map(default_layout, layout_in)) self.layout_out = layout_out = default_layout(layout_out) - self.transpose_in = torch.nn.ModuleList() + transpose_in = torch.nn.ModuleList() for layout_used, input_expected in zip(layout_in, e.inputs): - self.transpose_in.append( + transpose_in.append( cuet.TransposeIrrepsLayout( input_expected.irreps, source=layout_used, target=input_expected.layout, device=device, + use_fallback=use_fallback, ) ) + + # script() requires literal addressing and fails to eliminate dead branches + self.transpose_in = TRANSPOSE_DISPATCHERS[len(transpose_in) - 1](transpose_in) + self.transpose_out = cuet.TransposeIrrepsLayout( e.output.irreps, source=e.output.layout, target=layout_out, device=device, + use_fallback=use_fallback, ) if any(d.num_operands != e.num_inputs + 1 for d in e.ds): - self.tp = None - if e.num_inputs == 1: - self.symm_tp = cuet.SymmetricTensorProduct( - e.ds, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = SymmetricTPDispatcher( + cuet.SymmetricTensorProduct( + e.ds, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, + ) ) elif e.num_inputs == 2: - self.symm_tp = cuet.IWeightedSymmetricTensorProduct( - e.ds, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = IWeightedSymmetricTPDispatcher( + cuet.IWeightedSymmetricTensorProduct( + e.ds, + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, + ) ) else: raise NotImplementedError("This should not happen") else: - [d] = e.ds - - self.tp = cuet.TensorProduct( - d, - device=device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, + self.tp = TPDispatcher( + cuet.TensorProduct( + e.ds[0], + device=device, + math_dtype=math_dtype, + use_fallback=use_fallback, + optimize_fallback=optimize_fallback, + ) ) - self.symm_tp = None self.operands_dims = [op.irreps.dim for op in e.operands] @@ -140,60 +238,22 @@ def forward( self, inputs: List[torch.Tensor], indices: Optional[torch.Tensor] = None, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: """ If ``indices`` is not None, the first input is indexed by ``indices``. """ - inputs: List[torch.Tensor] = list(inputs) - assert len(inputs) == len(self.etp.inputs) + # assert len(inputs) == len(self.etp.inputs) for a, dim in zip(inputs, self.operands_dims): assert a.shape[-1] == dim # Transpose inputs - inputs = [ - t(a, use_fallback=use_fallback) for t, a in zip(self.transpose_in, inputs) - ] + inputs = self.transpose_in(inputs) # Compute tensor product - output = None - - if self.tp is not None: - if indices is not None: - # TODO: at some point we will have kernel for this - assert len(inputs) >= 1 - inputs[0] = inputs[0][indices] - output = self.tp(inputs, use_fallback=use_fallback) - - if self.symm_tp is not None: - if len(inputs) == 1: - assert indices is None - output = self.symm_tp(inputs[0], use_fallback=use_fallback) - - if len(inputs) == 2: - [x0, x1] = inputs - if indices is None: - torch._assert( - x0.ndim == 2, - f"Expected x0 to have shape (batch, dim), got {x0.shape}", - ) - if x0.shape[0] == 1: - indices = torch.zeros( - (x1.shape[0],), dtype=torch.int32, device=x1.device - ) - elif x0.shape[0] == x1.shape[0]: - indices = torch.arange( - x1.shape[0], dtype=torch.int32, device=x1.device - ) - - if indices is not None: - output = self.symm_tp(x0, indices, x1, use_fallback=use_fallback) - - if output is None: - raise NotImplementedError("This should not happen") + output = self.tp(inputs, indices) # Transpose output - output = self.transpose_out(output, use_fallback=use_fallback) + output = self.transpose_out(output) return output diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 79d92a6..dc61230 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import math from typing import Optional import torch @@ -21,6 +20,7 @@ import cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet +from cuequivariance_torch.primitives.tensor_product import broadcast_shapes, prod logger = logging.getLogger(__name__) @@ -41,6 +41,7 @@ def __init__( *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() @@ -55,6 +56,7 @@ def __init__( d0, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) else: @@ -82,12 +84,11 @@ def __init__( descriptors, device=device, math_dtype=math_dtype, + use_fallback=use_fallback, optimize_fallback=optimize_fallback, ) - def forward( - self, x0: torch.Tensor, use_fallback: Optional[bool] = None - ) -> torch.Tensor: + def forward(self, x0: torch.Tensor) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -106,7 +107,6 @@ def forward( torch.ones((1, 1), dtype=x0.dtype, device=x0.device), torch.zeros((x0.shape[0],), dtype=torch.int32, device=x0.device), x0, - use_fallback=use_fallback, ) if self.f0 is not None: out += self.f0([]) @@ -134,37 +134,54 @@ def __init__( *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() + if math_dtype is None: + math_dtype = torch.get_default_dtype() + _check_descriptors(descriptors) self.descriptors = descriptors - try: - self.f_cuda = CUDAKernel(descriptors, device, math_dtype) - except NotImplementedError as e: - logger.info(f"Failed to initialize CUDA implementation: {e}") - self.f_cuda = None - except ImportError as e: - logger.warning(f"Failed to initialize CUDA implementation: {e}") - self.f_cuda = None - - self.f_fx = FallbackImpl( - descriptors, - device, - math_dtype=math_dtype, - optimize_fallback=optimize_fallback, - ) - d = next(d for d in descriptors if d.num_operands >= 3) self.x0_size = d.operands[0].size self.x1_size = d.operands[1].size self.x2_size = d.operands[-1].size + self.has_cuda = False + + self.f = None + + if use_fallback is None or use_fallback is False: + try: + self.f = CUDAKernel(descriptors, device, math_dtype) + self.has_cuda = True + return + except NotImplementedError as e: + logger.info(f"Failed to initialize CUDA implementation: {e}") + except ImportError as e: + logger.warning(f"Failed to initialize CUDA implementation: {e}") + + if use_fallback is False and not self.has_cuda: + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available!" + ) + + if self.f is None: + self.f = FallbackImpl( + descriptors, + device, + math_dtype=math_dtype, + optimize_fallback=optimize_fallback, + ) + def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.f_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" + if self.has_cuda is not None + else "(without CUDA kernel)" ) return f"IWeightedSymmetricTensorProduct({has_cuda_kernel})" @@ -173,7 +190,6 @@ def forward( x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor, - use_fallback: Optional[bool] = None, ) -> torch.Tensor: r""" Perform the forward pass of the indexed symmetric tensor product operation. @@ -187,10 +203,6 @@ def forward( The index tensor for the first operand. It should have the shape (...). x1 : torch.Tensor The repeated input tensor. It should have the shape (..., x1_size). - use_fallback : Optional[bool], optional - If `None` (default), a CUDA kernel will be used if available. - If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. - If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns ------- @@ -203,28 +215,11 @@ def forward( x0.ndim == 2, f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}", ) - shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1]) - i0 = i0.expand(shape).reshape((math.prod(shape),)) - x1 = x1.expand(shape + (x1.shape[-1],)).reshape( - (math.prod(shape), x1.shape[-1]) - ) - - if ( - x0.device.type == "cuda" - and self.f_cuda is not None - and (use_fallback is not True) - ): - out = self.f_cuda(x0, i0, x1) - out = out.reshape(shape + (self.x2_size,)) - return out - - if use_fallback is False: - if self.f_cuda is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") + shape = broadcast_shapes([i0.shape, x1.shape[:-1]]) + i0 = i0.expand(shape).reshape((prod(shape),)) + x1 = x1.expand(shape + (x1.shape[-1],)).reshape((prod(shape), x1.shape[-1])) - out = self.f_fx(x0, i0, x1) + out = self.f(x0, i0, x1) out = out.reshape(shape + (self.x2_size,)) return out @@ -258,13 +253,10 @@ def __init__( self, stps: list[stp.SegmentedTensorProduct], device: Optional[torch.device], - math_dtype: Optional[torch.dtype], + math_dtype: torch.dtype, ): super().__init__() - if math_dtype is None: - math_dtype = torch.get_default_dtype() - max_degree = max(d.num_operands - 2 for d in stps) if max_degree > 6: raise NotImplementedError("Correlation > 6 is not implemented.") @@ -335,9 +327,10 @@ def forward( i0 = i0.to(torch.int32) x0 = x0.reshape(x0.shape[0], x0.shape[1] // self.u, self.u) x1 = x1.reshape(x1.shape[0], x1.shape[1] // self.u, self.u) - logger.debug( - f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + logger.debug( + f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}" + ) out = self.f(x1, x0, i0) out = out.reshape(out.shape[0], out.shape[1] * self.u) return out @@ -358,6 +351,7 @@ def __init__( d, device=device, math_dtype=math_dtype, + use_fallback=True, optimize_fallback=optimize_fallback, ) for d in stps @@ -368,6 +362,5 @@ def forward( self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor ) -> torch.Tensor: return sum( - f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True) - for f in self.fs + f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2)) for f in self.fs ) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 1dceb9e..b436161 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -25,6 +25,53 @@ logger = logging.getLogger(__name__) +def prod(numbers: List[int]): + product = 1 + for num in numbers: + product *= num + return product + + +def broadcast_shapes(shapes: List[List[int]]): + if torch.jit.is_scripting(): + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, (tuple, list)): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, (tuple, list)): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError( + "Trying to create tensor with negative dimension ({}): ({})".format( + shape[i], shape[i] + ) + ) + if shape[i] == 1 or shape[i] == result[i]: + continue + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape" + ) + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) + return torch.Size(result) + else: + return torch.functional.broadcast_shapes(*shapes) + + class TensorProduct(torch.nn.Module): """ PyTorch module that computes the last operand of the segmented tensor product defined by the descriptor. @@ -33,7 +80,12 @@ class TensorProduct(torch.nn.Module): descriptor (SegmentedTensorProduct): The descriptor of the segmented tensor product. math_dtype (torch.dtype, optional): The data type of the coefficients and calculations. device (torch.device, optional): The device on which the calculations are performed. + use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. + optimize_fallback (bool, optional): If `True`, the fallback method is optimized. If `False`, the fallback method is used without optimization. + Raises: + RuntimeError: If `use_fallback` is `False` and no CUDA kernel is available. + """ def __init__( @@ -42,41 +94,57 @@ def __init__( *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, + use_fallback: Optional[bool] = None, optimize_fallback: Optional[bool] = None, ): super().__init__() self.descriptor = descriptor - if math_dtype is None: math_dtype = torch.get_default_dtype() + self.f = None + self.has_cuda = False + self.num_operands = descriptor.num_operands - try: - self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype) - except NotImplementedError as e: - logger.info(f"CUDA implementation not available: {e}") - self.f_cuda = None - except ImportError as e: - logger.warning(f"CUDA implementation not available: {e}") - logger.warning( - "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" - "Install it with one of the following commands:\n" - "pip install cuequivariance-ops-torch-cu11\n" - "pip install cuequivariance-ops-torch-cu12" + if use_fallback is None or use_fallback is False: + try: + self.f = _tensor_product_cuda(descriptor, device, math_dtype) + self.has_cuda = True + except NotImplementedError as e: + logger.info(f"CUDA implementation not available: {e}") + except ImportError as e: + logger.warning(f"CUDA implementation not available: {e}") + logger.warning( + "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n" + "Install it with one of the following commands:\n" + "pip install cuequivariance-ops-torch-cu11\n" + "pip install cuequivariance-ops-torch-cu12" + ) + + if use_fallback is False and not self.has_cuda: + raise RuntimeError( + "`use_fallback` is `False` and no CUDA kernel is available!" ) - self.f_cuda = None - self.f_fx = _tensor_product_fx( - descriptor, device, math_dtype, optimize_fallback is True - ) - self._optimize_fallback = optimize_fallback + if self.f is None: + if optimize_fallback is None: + optimize_fallback = False + warnings.warn( + "The fallback method is used but it has not been optimized. " + "Consider setting optimize_fallback=True when creating the TensorProduct module." + ) + + self.f = _tensor_product_fx( + descriptor, device, math_dtype, optimize_fallback + ) + self._optimize_fallback = optimize_fallback def __repr__(self): has_cuda_kernel = ( - "(with CUDA kernel)" if self.f_cuda is not None else "(without CUDA kernel)" + "(with CUDA kernel)" if self.has_cuda else "(without CUDA kernel)" ) return f"TensorProduct({self.descriptor} {has_cuda_kernel})" - def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None): + def forward(self, inputs: List[torch.Tensor]): r""" Perform the tensor product based on the specified descriptor. @@ -84,9 +152,6 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one. Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size of each operand as defined in the tensor product descriptor. - use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input - is on CUDA. If `False`, a CUDA kernel will be used, and an exception is raised if it's not available or the - input is not on CUDA. If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability. Returns: torch.Tensor: @@ -94,32 +159,11 @@ def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = Non It has a shape of (batch, last_operand_size), where `last_operand_size` is the size of the last operand in the descriptor. - Raises: - RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA. """ - if any(x.numel() == 0 for x in inputs): - use_fallback = True # Empty tensors are not supported by the CUDA kernel - - if ( - inputs - and inputs[0].device.type == "cuda" - and self.f_cuda is not None - and (use_fallback is not True) - ): - return self.f_cuda(*inputs) - - if use_fallback is False: - if self.f_cuda is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") + # if any(x.numel() == 0 for x in inputs): + # use_fallback = True # Empty tensors are not supported by the CUDA kernel - if self._optimize_fallback is None: - warnings.warn( - "The fallback method is used but it has not been optimized. " - "Consider setting optimize_fallback=True when creating the TensorProduct module." - ) - return self.f_fx(inputs) + return self.f(inputs) def _tensor_product_fx( @@ -189,9 +233,7 @@ def _tensor_product_fx( seg_shape = descriptor.get_segment_shape(-1, path) outputs += [ - out.reshape( - out.shape[: out.ndim - len(seg_shape)] + (math.prod(seg_shape),) - ) + out.reshape(out.shape[: out.ndim - len(seg_shape)] + (prod(seg_shape),)) ] if len(outputs) == 0: @@ -206,7 +248,7 @@ def _tensor_product_fx( for out, path in zip(outputs, descriptor.paths) if path.indices[-1] == i ], - shape=batch_shape + (math.prod(descriptor.operands[-1][i]),), + shape=batch_shape + (prod(descriptor.operands[-1][i]),), like=outputs[0], ) for i in range(descriptor.operands[-1].num_segments) @@ -237,9 +279,9 @@ def _tensor_product_fx( for operand in descriptor.operands[:num_inputs] ] graphmod = opt_einsum_fx.optimize_einsums_full(graphmod, example_inputs) - else: + elif num_inputs == 0: - class _no_input_or_no_paths(torch.nn.Module): + class _no_input(torch.nn.Module): def __init__(self, descriptor: stp.SegmentedTensorProduct): super().__init__() @@ -251,12 +293,9 @@ def __init__(self, descriptor: stp.SegmentedTensorProduct): ), ) - def forward(self, *args): - shape = torch.broadcast_shapes(*[arg.shape[:-1] for arg in args]) + def forward(self): output = torch.zeros( - shape + (descriptor.operands[-1].size,), - device=device, - dtype=math_dtype, + (descriptor.operands[-1].size,), device=device, dtype=math_dtype ) for pid in range(descriptor.num_paths): output += torch.einsum( @@ -267,38 +306,103 @@ def forward(self, *args): ) return output - graphmod = _no_input_or_no_paths(descriptor) + graphmod = _no_input(descriptor) + + else: + raise NotImplementedError( + "No FX implementation for empty paths and non-empty inputs" + ) return _Wrapper(graphmod, descriptor) +class _Caller(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.module = module + + +class _NoArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module() + + +class _OneArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0]) + + +class _TwoArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1]) + + +class _ThreeArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2]) + + +class _FourArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3]) + + +class _FiveArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3], args[4]) + + +class _SixArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module(args[0], args[1], args[2], args[3], args[4], args[5]) + + +class _SevenArgCaller(_Caller): + def forward(self, args: List[torch.Tensor]): + return self.module( + args[0], args[1], args[2], args[3], args[4], args[5], args[6] + ) + + +CALL_DISPATCHERS = [ + _NoArgCaller, + _OneArgCaller, + _TwoArgCaller, + _ThreeArgCaller, + _FourArgCaller, + _FiveArgCaller, + _SixArgCaller, + _SevenArgCaller, +] + + class _Wrapper(torch.nn.Module): def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProduct): super().__init__() - self.module = module + self.module = CALL_DISPATCHERS[descriptor.num_operands - 1](module) self.descriptor = descriptor - def forward(self, args): - for oid, arg in enumerate(args): - torch._assert( - arg.shape[-1] == self.descriptor.operands[oid].size, - "input shape[-1] does not match operand size", - ) + def forward(self, args: List[torch.Tensor]): + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + for oid, arg in enumerate(args): + torch._assert( + arg.shape[-1] == self.descriptor.operands[oid].size, + "input shape[-1] does not match operand size", + ) - shape = torch.broadcast_shapes(*[arg.shape[:-1] for arg in args]) + shape = broadcast_shapes([arg.shape[:-1] for arg in args]) args = [ ( arg.expand(shape + (arg.shape[-1],)).reshape( - (math.prod(shape), arg.shape[-1]) + (prod(shape), arg.shape[-1]) ) - if math.prod(arg.shape[:-1]) > 1 - else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1])) + if prod(arg.shape[:-1]) > 1 + else arg.reshape((prod(arg.shape[:-1]), arg.shape[-1])) ) for arg in args ] - - out = self.module(*args) + out = self.module(args) return out.reshape(shape + (out.shape[-1],)) @@ -390,13 +494,13 @@ def _tensor_product_cuda( return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype) -def _reshape(x: torch.Tensor, leading_shape: tuple[int, ...]) -> torch.Tensor: +def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor: # Make x have shape (Z, x.shape[-1]) or (x.shape[-1],) - if math.prod(leading_shape) > 1 and math.prod(x.shape[:-1]) == 1: + if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1: return x.reshape((x.shape[-1],)) else: return x.expand(leading_shape + (x.shape[-1],)).reshape( - (math.prod(leading_shape), x.shape[-1]) + (prod(leading_shape), x.shape[-1]) ) @@ -434,24 +538,21 @@ def __init__( ).to(device=device) def __repr__(self) -> str: - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"FusedTensorProductOp3({self.descriptor} (output last operand))" - def forward( - self, - x0: torch.Tensor, - x1: torch.Tensor, - ) -> torch.Tensor: - x0, x1 = self._perm(x0, x1) + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + x0, x1 = self._perm(inputs[0], inputs[1]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) - logger.debug( - f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + logger.debug( + f"Calling FusedTensorProductOp3: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" + ) out = self._f(x0, x1) @@ -492,34 +593,30 @@ def __init__( ).to(device=device) def __repr__(self) -> str: - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"FusedTensorProductOp4({self.descriptor} (output last operand))" - def forward( - self, - x0: torch.Tensor, - x1: torch.Tensor, - x2: torch.Tensor, - ) -> torch.Tensor: - x0, x1, x2 = self._perm(x0, x1, x2) + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + x0, x1, x2 = self._perm(inputs[0], inputs[1], inputs[2]) assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) - logger.debug( - f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + logger.debug( + f"Calling FusedTensorProductOp4: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" + ) out = self._f(x0, x1, x2) return out.reshape(shape + (out.shape[-1],)) -class TensorProductUniform3x1d(torch.nn.Module): +class TensorProductUniform1d(torch.nn.Module): def __init__( self, descriptor: stp.SegmentedTensorProduct, @@ -545,14 +642,17 @@ def __init__( math_dtype=math_dtype, ).to(device=device) + +class TensorProductUniform3x1d(TensorProductUniform1d): def __repr__(self): - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"TensorProductUniform3x1d({self.descriptor} (output last operand))" - def forward(self, x0, x1): + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + x0, x1 = inputs assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) @@ -561,50 +661,27 @@ def forward(self, x0, x1): if x1.ndim == 1: x1 = x1.unsqueeze(0) - logger.debug( - f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + logger.debug( + f"Calling TensorProductUniform3x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}" + ) - out = self._f(x0, x1) + out = self._f(x0, x1, x0) return out.reshape(shape + (out.shape[-1],)) -class TensorProductUniform4x1d(torch.nn.Module): - def __init__( - self, - descriptor: stp.SegmentedTensorProduct, - device: Optional[torch.device], - math_dtype: torch.dtype, - ): - super().__init__() - import cuequivariance_ops_torch as ops - - self.descriptor = descriptor - - assert len(descriptor.subscripts.modes()) == 1 - assert descriptor.all_same_segment_shape() - assert descriptor.coefficient_subscripts == "" - u = next(iter(descriptor.get_dims(descriptor.subscripts.modes()[0]))) - - self._f = ops.TensorProductUniform1d( - operand_dim=[ope.ndim for ope in descriptor.operands], - operand_extent=u, - operand_num_segments=[ope.num_segments for ope in descriptor.operands], - path_indices=[path.indices for path in descriptor.paths], - path_coefficients=[float(path.coefficients) for path in descriptor.paths], - math_dtype=math_dtype, - ).to(device=device) - +class TensorProductUniform4x1d(TensorProductUniform1d): def __repr__(self): - return f"TensorProductCUDA({self.descriptor} (output last operand))" + return f"TensorProductUniform4x1d({self.descriptor} (output last operand))" - def forward(self, x0, x1, x2): + def forward(self, inputs: List[torch.Tensor]): + x0, x1, x2 = inputs assert x0.ndim >= 1, x0.ndim assert x1.ndim >= 1, x1.ndim assert x2.ndim >= 1, x2.ndim - shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]) + shape = broadcast_shapes([x0.shape[:-1], x1.shape[:-1], x2.shape[:-1]]) x0 = _reshape(x0, shape) x1 = _reshape(x1, shape) x2 = _reshape(x2, shape) @@ -616,9 +693,10 @@ def forward(self, x0, x1, x2): if x2.ndim == 1: x2 = x2.unsqueeze(0) - logger.debug( - f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" - ) + if not torch.jit.is_scripting() and not torch.compiler.is_compiling(): + logger.debug( + f"Calling TensorProductUniform4x1d: {self.descriptor}, input shapes: {x0.shape}, {x1.shape}, {x2.shape}" + ) out = self._f(x0, x1, x2) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index c8709b2..b23777a 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -36,19 +36,24 @@ def __init__( source: cue.IrrepsLayout, target: cue.IrrepsLayout, device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False, ): super().__init__() if (source, target) == (cue.mul_ir, cue.ir_mul): self.f = TransposeSegments( - [(mul, ir.dim) for mul, ir in irreps], device=device + [(mul, ir.dim) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) elif (source, target) == (cue.ir_mul, cue.mul_ir): self.f = TransposeSegments( - [(ir.dim, mul) for mul, ir in irreps], device=device + [(ir.dim, mul) for mul, ir in irreps], + device=device, + use_fallback=use_fallback, ) else: - self.f = _Identity() + self.f = torch.nn.Identity() self.source, self.target = source, target @@ -58,9 +63,7 @@ def __init__( def __repr__(self): return f"TransposeIrrepsLayout({self.source} -> {self.target})" - def forward( - self, x: torch.Tensor, *, use_fallback: Optional[bool] = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: r""" Perform the transposition. @@ -74,41 +77,42 @@ def forward( torch.Tensor: The transposed tensor. """ - return self.f(x, use_fallback=use_fallback) - - -class _Identity(torch.nn.Module): - def forward(self, x: torch.Tensor, **kwargs): - return x + return self.f(x) class TransposeSegments(torch.nn.Module): def __init__( - self, segments: list[tuple[int, int]], device: Optional[torch.device] = None + self, + segments: list[tuple[int, int]], + device: Optional[torch.device] = None, + use_fallback: Optional[bool] = False, ): super().__init__() info = _transpose_info(segments, device=device) + self.f = None if info is not None: - try: - import cuequivariance_ops_torch # noqa - except ImportError: - self.f_cuda = None - else: - self.f_cuda = _transpose(info).to(device=device) - - self.f = _transpose_segments_fx(segments).to(device=device) + if use_fallback is False or use_fallback is None: + try: + import cuequivariance_ops_torch # noqa: F401 + except ImportError: + pass + else: + self.f = _transpose(info).to(device=device) + + if use_fallback is False and self.f is None: + raise RuntimeError("CUDA kernel not available for TransposeSegments.") + + if self.f is None: + self.f = _transpose_segments_fx(segments).to(device=device) else: - self.f_cuda = torch.nn.Identity() self.f = torch.nn.Identity() def __repr__(self): return "TransposeSegments()" - def forward( - self, x: torch.Tensor, *, use_fallback: Optional[bool] = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Perform the transposition of the input tensor using either a CUDA kernel or a PyTorch fallback. @@ -131,19 +135,6 @@ def forward( RuntimeError If `use_fallback` is `False` and a CUDA kernel is not available or the input is not on CUDA. """ - if ( - x.device.type == "cuda" - and self.f_cuda is not None - and (use_fallback is not True) - ): - return self.f_cuda(x) - - if use_fallback is False: - if self.f_cuda is not None: - raise RuntimeError("CUDA kernel available but input is not on CUDA") - else: - raise RuntimeError("No CUDA kernel available") - return self.f(x) @@ -184,12 +175,16 @@ def _transpose_info( return torch.IntTensor(info).to(device=device) +try: + from cuequivariance_ops_torch import segmented_transpose +except ImportError: + pass + + class _transpose(torch.nn.Module): def __init__(self, info: torch.IntTensor): super().__init__() self.register_buffer("_info", info, persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: - from cuequivariance_ops_torch import segmented_transpose - return segmented_transpose(x, self._info, True) diff --git a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py index 198c6d8..96602a3 100644 --- a/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py +++ b/cuequivariance_torch/tests/layers/tp_conv_fully_connected_test.py @@ -21,7 +21,7 @@ import cuequivariance_torch as cuet from cuequivariance_torch.layers.tp_conv_fully_connected import scatter_reduce -device = torch.device("cuda:0") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @pytest.mark.parametrize("layout", [cue.mul_ir, cue.ir_mul]) @@ -133,7 +133,6 @@ def D(irreps, axis, angle): @pytest.mark.parametrize("reduce", ["sum", "mean", "prod", "amax", "amin"]) def test_scatter_reduce(reduce: str): - device = torch.device("cuda") src = torch.Tensor([3, 1, 0, 1, 1, 2]) index = torch.Tensor([0, 1, 2, 2, 3, 1]) @@ -153,7 +152,6 @@ def test_scatter_reduce(reduce: str): def test_scatter_reduce_empty(): - device = torch.device("cuda") src, index = torch.empty((0, 41)), torch.empty((0,)) src = src.to(device) index = index.to(device) diff --git a/cuequivariance_torch/tests/operations/linear_test.py b/cuequivariance_torch/tests/operations/linear_test.py index 8eb7718..2b78ff1 100644 --- a/cuequivariance_torch/tests/operations/linear_test.py +++ b/cuequivariance_torch/tests/operations/linear_test.py @@ -20,6 +20,8 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("SU2", "3x1/2 + 4x1"), cue.Irreps("SU2", "2x1/2 + 5x1 + 2x1/2"), @@ -37,23 +39,39 @@ def test_linear_fwd( layout: cue.IrrepsLayout, shared_weights: bool, ): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(0) linear = cuet.Linear( irreps_in, irreps_out, layout=layout, shared_weights=shared_weights, - device="cuda", + device=device, + dtype=torch.float64, + use_fallback=False, + ) + + torch.manual_seed(0) + linear_fx = cuet.Linear( + irreps_in, + irreps_out, + layout=layout, + shared_weights=shared_weights, + device=device, dtype=torch.float64, + use_fallback=True, ) x = torch.randn(10, irreps_in.dim, dtype=torch.float64).cuda() if shared_weights: y = linear(x) - y_fx = linear(x, use_fallback=True) + y_fx = linear_fx(x) else: w = torch.randn(10, linear.weight_numel, dtype=torch.float64).cuda() y = linear(x, w) - y_fx = linear(x, w, use_fallback=True) + y_fx = linear_fx(x, w) assert y.shape == (10, irreps_out.dim) @@ -70,31 +88,36 @@ def test_linear_bwd_bwd( layout: cue.IrrepsLayout, shared_weights: bool, ): - linear = cuet.Linear( - irreps_in, - irreps_out, - layout=layout, - shared_weights=shared_weights, - device="cuda", - dtype=torch.float64, - ) + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") outputs = dict() for use_fallback in [True, False]: + torch.manual_seed(0) + linear = cuet.Linear( + irreps_in, + irreps_out, + layout=layout, + shared_weights=shared_weights, + device=device, + dtype=torch.float64, + use_fallback=use_fallback, + ) + # reset the seed to ensure the same initialization torch.manual_seed(0) x = torch.randn( - 10, irreps_in.dim, requires_grad=True, device="cuda", dtype=torch.float64 + 10, irreps_in.dim, requires_grad=True, device=device, dtype=torch.float64 ) if shared_weights: - y = linear(x, use_fallback=use_fallback) + y = linear(x) else: w = torch.randn( 10, linear.weight_numel, requires_grad=True, dtype=torch.float64 ).cuda() - y = linear(x, w, use_fallback=use_fallback) + y = linear(x, w) (grad,) = torch.autograd.grad( y.pow(2).sum(), @@ -143,6 +166,6 @@ def test_linear_copy( irreps_out, layout=layout, shared_weights=shared_weights, - ).cuda() + ).to(device) copy.deepcopy(linear) diff --git a/cuequivariance_torch/tests/operations/rotation_test.py b/cuequivariance_torch/tests/operations/rotation_test.py index a73c68b..86d0230 100644 --- a/cuequivariance_torch/tests/operations/rotation_test.py +++ b/cuequivariance_torch/tests/operations/rotation_test.py @@ -17,16 +17,18 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def test_rotation(): irreps = cue.Irreps("SO3", "3x0 + 1 + 0 + 4x2 + 4") - alpha = torch.tensor(0.3).cuda() - beta = torch.tensor(0.4).cuda() - gamma = torch.tensor(-0.5).cuda() + alpha = torch.tensor(0.3).to(device) + beta = torch.tensor(0.4).to(device) + gamma = torch.tensor(-0.5).to(device) - rot = cuet.Rotation(irreps, layout=cue.ir_mul).cuda() + rot = cuet.Rotation(irreps, layout=cue.ir_mul).to(device) - x = torch.randn(10, irreps.dim).cuda() + x = torch.randn(10, irreps.dim).to(device) rx = rot(gamma, beta, alpha, x) x_ = rot(-alpha, -beta, -gamma, rx) diff --git a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py index 6db1570..955ee87 100644 --- a/cuequivariance_torch/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_torch/tests/operations/spherical_harmonics_test.py @@ -19,6 +19,8 @@ import cuequivariance as cue import cuequivariance_torch as cuet +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + @pytest.mark.parametrize( "dtype, tol", @@ -26,15 +28,15 @@ ) @pytest.mark.parametrize("ell", [1, 2, 3]) def test_spherical_harmonics(ell: int, dtype, tol): - vec = torch.randn(3, dtype=dtype) + vec = torch.randn(3, dtype=dtype, device=device) axis = np.random.randn(3) angle = np.random.rand() scale = 1.3 yl = cuet.spherical_harmonics([ell], vec, False) - R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype) - Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype) + R = torch.from_numpy(cue.SO3(1).rotation(axis, angle)).to(dtype).to(device) + Rl = torch.from_numpy(cue.SO3(ell).rotation(axis, angle)).to(dtype).to(device) yl1 = cuet.spherical_harmonics([ell], scale * R @ vec, False) yl2 = scale**ell * Rl @ yl @@ -43,7 +45,7 @@ def test_spherical_harmonics(ell: int, dtype, tol): def test_spherical_harmonics_full(): - vec = torch.randn(3) + vec = torch.randn(3, device=device) ls = [0, 1, 2, 3] yl = cuet.spherical_harmonics(ls, vec, False) diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 09110c5..80a4065 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -22,6 +22,8 @@ import cuequivariance_torch as cuet from cuequivariance.experimental.e3nn import O3_e3nn +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + USE_TF32 = False torch.backends.cuda.matmul.allow_tf32 = USE_TF32 torch.backends.cudnn.allow_tf32 = USE_TF32 @@ -30,7 +32,7 @@ @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("original_mace", [True, False]) -@pytest.mark.parametrize("batch", [0, 32]) +@pytest.mark.parametrize("batch", [1, 32]) def test_symmetric_contraction(dtype, layout, original_mace, batch): mul = 64 irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e") @@ -45,12 +47,12 @@ def test_symmetric_contraction(dtype, layout, original_mace, batch): layout_out=layout, dtype=dtype, math_dtype=dtype, - device="cuda", + device=device, original_mace=original_mace, ) - x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda() - indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda() + x = torch.randn((batch, irreps_in.dim), dtype=dtype).to(device) + indices = torch.randint(0, 5, (batch,), dtype=torch.int32).to(device) out = m(x, indices) assert out.shape == (batch, irreps_out.dim) @@ -58,7 +60,7 @@ def test_symmetric_contraction(dtype, layout, original_mace, batch): def from64(shape: tuple[int, ...], data: str) -> torch.Tensor: x = np.frombuffer(base64.b64decode(data), dtype=np.float32).reshape(shape) - return torch.from_numpy(x.copy()).cuda() + return torch.from_numpy(x.copy()).to(device) def test_mace_compatibility(): @@ -76,7 +78,7 @@ def test_mace_compatibility(): irreps_in = cue.Irreps(O3_e3nn, "0e + 1o + 2e") irreps_out = cue.Irreps(O3_e3nn, "0e + 1o") - i = (torch.arange(3) % num_elements).cuda() + i = (torch.arange(3) % num_elements).to(device) x = from64( (3, 36), "mHgaP1zHTz5kdhs/3ygQvwzZf77dhoU8+iP+PzzRRD8L9CY+qi9Fv5aiBz/sGJG/xwaev+5w4b2Mbg8+1jDOP4/cwj9rt/u/FedUP7H6qD4y9LM+i7yvPhcifz8coHE/Vkk1PwK0hb/BNig+GF4gP1FNaD94Uj++d+1qPtkrYD8m8o6/9zK9PihGBz9M6Ne9XgCXP/r6bzxTXJO/glIsQPQlDL/fN5w7VeeKP4iYlD/9Msa/GF/cvg+2gz/oRJ6/0Te4P7g+oz8YQ6g+k0q0vN8WEr41/u0/sa55PmAhvD9FZZw/ICJtvyxFkz+zOAq/8JtNPztZX74E9hK/xCdqv4+0Rz9Ah/g+5vmDv6mLL7+M5DI/xgP3PhWEnj5ZmZ0+DBkXwPa12D1mVPo9rDdWP4DkRD+L85Y9EJ01P+8Hiz6gxSM7/eoPwOQOtr8gjge+NBEYPrmg5L2XpO8/F2tCvjEyWL8gjLw+UOIuP5bhPr9qRvM+ADa5v3rqLLwSr/8+PbZhP4tn675SWVm/SMC1P5h/0r0D8v2/CNS7Pza7SL8PqJG+DsKCOpTKoT+xnLg/", @@ -95,7 +97,7 @@ def test_mace_compatibility(): layout_in=cue.ir_mul, layout_out=cue.mul_ir, original_mace=True, - device="cuda", + device=device, dtype=torch.float32, math_dtype=torch.float64, ) diff --git a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py index 9540c73..d3e7cdd 100644 --- a/cuequivariance_torch/tests/operations/tp_channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/tp_channel_wise_test.py @@ -19,6 +19,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), @@ -31,7 +33,7 @@ @pytest.mark.parametrize("irreps3", list_of_irreps) @pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir]) @pytest.mark.parametrize("use_fallback", [False, True]) -@pytest.mark.parametrize("batch", [0, 32]) +@pytest.mark.parametrize("batch", [1, 32]) def test_channel_wise( irreps1: cue.Irreps, irreps2: cue.Irreps, @@ -40,61 +42,70 @@ def test_channel_wise( use_fallback: bool, batch: int, ): - m = cuet.ChannelWiseTensorProduct( + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m1 = cuet.ChannelWiseTensorProduct( irreps1, irreps2, irreps3, shared_weights=True, internal_weights=True, layout=layout, - device="cuda", + device=device, dtype=torch.float64, + use_fallback=use_fallback, ) + x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).to(device) + x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).to(device) - x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda() - - out1 = m(x1, x2, use_fallback=use_fallback) + out1 = m1(x1, x2) d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: d = d.add_or_transpose_modes("u,ui,j,uk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() - out2 = mfx([m.weight, x1, x2], use_fallback=True) + m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) + out2 = m2([m1.weight, x1, x2]) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) def test_channel_wise_bwd_bwd(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + irreps1 = cue.Irreps("SO3", "2x0 + 3x1") irreps2 = cue.Irreps("SO3", "0 + 1") irreps3 = cue.Irreps("SO3", "0 + 1") - m = cuet.ChannelWiseTensorProduct( - irreps1, - irreps2, - irreps3, - shared_weights=True, - internal_weights=False, - layout=cue.ir_mul, - device="cuda", - dtype=torch.float64, - ) - x1 = torch.randn( - 32, irreps1.dim, device="cuda", requires_grad=True, dtype=torch.float64 + 32, irreps1.dim, device=device, requires_grad=True, dtype=torch.float64 ) x2 = torch.randn( - 32, irreps2.dim, device="cuda", requires_grad=True, dtype=torch.float64 - ) - w = torch.randn( - m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 + 32, irreps2.dim, device=device, requires_grad=True, dtype=torch.float64 ) outputs = {} for use_fallback in [True, False]: + m = cuet.ChannelWiseTensorProduct( + irreps1, + irreps2, + irreps3, + shared_weights=True, + internal_weights=False, + layout=cue.ir_mul, + device="cuda", + dtype=torch.float64, + use_fallback=use_fallback, + ) + + torch.manual_seed(0) + w = torch.randn( + m.weight_numel, device="cuda", requires_grad=True, dtype=torch.float64 + ) + (grad1, grad2, grad3) = torch.autograd.grad( m(x1, x2, w).pow(2).sum(), (x1, x2, w), create_graph=True ) diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py index 4e197fd..832904b 100644 --- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py @@ -19,6 +19,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + list_of_irreps = [ cue.Irreps("O3", "4x0e + 4x1o"), cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"), @@ -38,44 +40,52 @@ def test_fully_connected( layout: cue.IrrepsLayout, use_fallback: bool, ): - m = cuet.FullyConnectedTensorProduct( + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m1 = cuet.FullyConnectedTensorProduct( irreps1, irreps2, irreps3, shared_weights=True, internal_weights=True, layout=layout, - device="cuda", + device=device, dtype=torch.float64, + use_fallback=use_fallback, ) - x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda() - x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda() + x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).to(device) + x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).to(device) - out1 = m(x1, x2, use_fallback=use_fallback) + out1 = m1(x1, x2) d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") - mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda() - out2 = mfx( - [m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], - use_fallback=True, + m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) + out2 = m2( + [m1.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)], ).to(out1.dtype) torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) -def test_compile(): +@pytest.mark.parametrize("use_fallback", [False, True]) +def test_compile(use_fallback: bool): + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + m = cuet.FullyConnectedTensorProduct( irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"), irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"), irreps_out=cue.Irreps("O3", "32x0e + 32x1o"), layout=cue.mul_ir, - optimize_fallback=False, + device=device, + use_fallback=use_fallback, ) m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, m.irreps_in1.dim) - input2 = torch.randn(100, m.irreps_in2.dim) + input1 = torch.randn(100, m.irreps_in1.dim, device=device) + input2 = torch.randn(100, m.irreps_in2.dim, device=device) m_compile(input1, input2) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 04d3aef..043e387 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -12,15 +12,209 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import timeit import pytest import torch +import torch._dynamo import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors +torch._dynamo.config.cache_size_limit = 100 + + +try: + import cuequivariance_ops_torch.onnx # noqa: F401 + import onnx # noqa: F401 + import onnxruntime # noqa: F401 + import onnxscript # noqa: F401 + from cuequivariance_ops_torch.tensorrt import register_plugins + + ONNX_AVAILABLE = True +except Exception: + ONNX_AVAILABLE = False + + +try: + import torch_tensorrt + + TORCH_TRT_AVAILABLE = True +except Exception: + TORCH_TRT_AVAILABLE = False + + +def verify_onnx(module, onnx_module, inputs, dtype): + if dtype != torch.float32: + pytest.skip("onnxrt only checked for float32") + from onnxruntime import SessionOptions + from onnxruntime_extensions import get_library_path + from torch.onnx.verification import ( + VerificationOptions, + _compare_onnx_pytorch_model, + ) + + original_init = SessionOptions.__init__ + + def new_init(self): + original_init(self) + try: + self.register_custom_ops_library(get_library_path()) + except Exception: + pass + + SessionOptions.__init__ = new_init + _compare_onnx_pytorch_model( + module, onnx_module, tuple(inputs), None, None, VerificationOptions() + ) + SessionOptions.__init__ = original_init + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def verify_trt(module, onnx_module, inputs, dtype): + import tensorrt + from pkg_resources import parse_version + + if parse_version(tensorrt.__version__) < parse_version("10.3.0"): + pytest.skip("TRT < 10.3.0 is not supported!") + if dtype == torch.float64: + pytest.skip("TRT does not support float64") + + from onnxruntime import InferenceSession, SessionOptions + from onnxruntime_extensions import get_library_path + from polygraphy.backend.onnxrt import OnnxrtRunner + from polygraphy.backend.trt import ( + CreateConfig, + TrtRunner, + engine_from_network, + network_from_onnx_path, + ) + from polygraphy.comparator import Comparator, DataLoader + + register_plugins() + + network = network_from_onnx_path(onnx_module) + trt_engine = engine_from_network(network, config=CreateConfig()) + + if dtype != torch.float32: + pytest.skip("Comparator only supports float32") + + # Create runners for ONNX and TRT models + trt_runner = TrtRunner(trt_engine) + + options = SessionOptions() + options.register_custom_ops_library(get_library_path()) + onnx_runner = OnnxrtRunner(InferenceSession(onnx_module, sess_options=options)) + + results = Comparator.run([trt_runner, onnx_runner], data_loader=DataLoader()) + Comparator.compare_accuracy(results) + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def module_with_mode( + mode, + module, + inputs, + math_dtype, + tmp_path, + grad_modes=["eager", "compile", "jit", "export"], +): + if isinstance(inputs[0], list): + dtype = inputs[0][0].dtype + else: + dtype = inputs[0].dtype + if mode in ["trt", "torch_trt", "onnx", "onnx_dynamo", "export"]: + if not ONNX_AVAILABLE: + pytest.skip("ONNX not available!") + if dtype == torch.float64 or math_dtype == torch.float64: + pytest.skip("TRT/ORT do not support float64") + + with torch.set_grad_enabled(mode in grad_modes): + if mode == "compile": + import sys + + if sys.version_info.major == 3 and sys.version_info.minor >= 12: + pytest.skip("torch dynamo needs cpy <= 3.11") + module = torch.compile(module) + elif mode == "fx": + module = torch.fx.symbolic_trace(module) + elif mode == "jit": + module = torch.jit.trace(module, inputs) + fname = os.path.join(tmp_path, "test.ts") + torch.jit.save(module, fname) + module = torch.jit.load(fname) + elif mode == "export": + exp_program = torch.export.export(module, tuple(inputs)) + fname = os.path.join(tmp_path, "test.pt2") + torch.export.save(exp_program, fname) + del exp_program + module = torch.export.load(fname).module() + elif mode == "torch_trt": + if not TORCH_TRT_AVAILABLE: + pytest.skip("torch_tensorrt is not installed!") + register_plugins() + exp_program = torch_tensorrt.dynamo.trace(module, inputs) + module = torch_tensorrt.dynamo.compile( + exp_program, + inputs=inputs, + require_full_compilation=True, + min_block_size=1, + enabled_precisions={torch.float32, dtype}, + # dryrun=True + ) + elif mode == "onnx" or mode == "trt": + try: + onnx_path = os.path.join(tmp_path, "test.onnx") + torch.onnx.export( + module, tuple(inputs), onnx_path, opset_version=17, verbose=False + ) + if mode == "trt": + verify_trt(module, onnx_path, inputs, dtype) + else: + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX/TRT is not available") + + elif mode == "onnx_dynamo": + try: + from cuequivariance_ops_torch.onnx import ( + cuequivariance_ops_torch_onnx_registry, + ) + + export_options = torch.onnx.ExportOptions( + onnx_registry=cuequivariance_ops_torch_onnx_registry + ) + onnx_program = torch.onnx.dynamo_export( + module, *inputs, export_options=export_options + ) + onnx_path = os.path.join(tmp_path, "test.onnx") + onnx_program.save(onnx_path) + verify_onnx(module, onnx_path, inputs, dtype) + except ImportError: + pytest.skip("ONNX is not available") + elif mode == "eager": + pass + else: + raise ValueError(f"No such mode: {mode}") + + torch.cuda.synchronize() + torch.cuda.empty_cache() + + return module + + +def maybe_detach_and_to(tensor, *args, **kwargs): + if tensor is not None: + return tensor.clone().detach().to(*args, **kwargs) + return None + + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def make_descriptors(): # This ETP will trigger the fusedTP kernel @@ -54,7 +248,7 @@ def make_descriptors(): (torch.float32, torch.float32), (torch.float64, torch.float64), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings1 += [ (torch.float16, torch.float32), (torch.bfloat16, torch.float32), @@ -68,13 +262,23 @@ def test_performance_cuda_vs_fx( dtype: torch.dtype, math_dtype: torch.dtype, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, + use_fallback=False, + ) + + m1 = cuet.EquivariantTensorProduct( + e, + layout=cue.ir_mul, + device=device, + math_dtype=math_dtype, + use_fallback=True, optimize_fallback=True, ) @@ -84,15 +288,22 @@ def test_performance_cuda_vs_fx( ] for _ in range(10): - m(inputs, use_fallback=False) - m(inputs, use_fallback=True) + m(inputs) + m1(inputs) + torch.cuda.synchronize() - def f(ufb: bool): - m(inputs, use_fallback=ufb) - torch.cuda.synchronize() + def f(): + ret = m(inputs) + ret = torch.sum(ret) + return ret - t0 = timeit.Timer(lambda: f(False)).timeit(number=10) - t1 = timeit.Timer(lambda: f(True)).timeit(number=10) + def f1(): + ret = m1(inputs) + ret = torch.sum(ret) + return ret + + t0 = timeit.Timer(f).timeit(number=10) + t1 = timeit.Timer(f1).timeit(number=10) assert t0 < t1 @@ -102,7 +313,7 @@ def f(ufb: bool): (torch.float64, torch.float32, 1e-5, 1e-6), (torch.float64, torch.float64, 1e-12, 0), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings2 += [ (torch.float16, torch.float32, 1, 0.2), (torch.bfloat16, torch.float32, 1, 0.2), @@ -118,39 +329,109 @@ def test_precision_cuda_vs_fx( atol: float, rtol: float, ): - device = torch.device("cuda:0") + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") inputs = [ torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) for inp in e.inputs ] m = cuet.EquivariantTensorProduct( - e, - layout=cue.ir_mul, - device=device, - math_dtype=math_dtype, + e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, use_fallback=False ) - y0 = m(inputs, use_fallback=False) + y0 = m(inputs) m = cuet.EquivariantTensorProduct( e, layout=cue.ir_mul, device=device, math_dtype=torch.float64, + use_fallback=True, optimize_fallback=True, ) - inputs = map(lambda x: x.to(torch.float64), inputs) - y1 = m(inputs, use_fallback=True).to(dtype) + inputs = [x.to(torch.float64) for x in inputs] + y1 = m(inputs).to(dtype) torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol) -def test_compile(): - e = cue.descriptors.symmetric_contraction( - cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3] +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +def test_compile( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype ) - m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False) + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) m_compile = torch.compile(m, fullgraph=True) - input1 = torch.randn(100, e.inputs[0].irreps.dim) - input2 = torch.randn(100, e.inputs[1].irreps.dim) - m_compile([input1, input2]) + res_script = m_compile(inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +def test_script( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, use_fallback=False, device=device, math_dtype=math_dtype + ) + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) + m_script = torch.jit.script(m) + res_script = m_script(inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) + + +# export_modes = ["onnx", "onnx_dynamo", "trt", "torch_trt", "jit"] +export_modes = ["trt", "onnx"] + + +@pytest.mark.parametrize("e", make_descriptors()) +@pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2) +@pytest.mark.parametrize("mode", export_modes) +def test_export( + e: cue.EquivariantTensorProduct, + dtype: torch.dtype, + math_dtype: torch.dtype, + atol: float, + rtol: float, + mode: str, + tmp_path, +): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + m = cuet.EquivariantTensorProduct( + e, layout=cue.mul_ir, math_dtype=math_dtype, use_fallback=False, device=device + ) + inputs = [ + torch.randn((1024, inp.irreps.dim), device=device, dtype=dtype) + for inp in e.inputs + ] + res = m(inputs) + m_script = module_with_mode(mode, m, [inputs], math_dtype, tmp_path) + res_script = m_script(inputs) + torch.testing.assert_close(res, res_script, atol=atol, rtol=rtol) diff --git a/cuequivariance_torch/tests/primitives/script_test.py b/cuequivariance_torch/tests/primitives/script_test.py new file mode 100644 index 0000000..4706bff --- /dev/null +++ b/cuequivariance_torch/tests/primitives/script_test.py @@ -0,0 +1,125 @@ +import pytest +import torch + +import cuequivariance as cue +from cuequivariance_torch.primitives.symmetric_tensor_product import ( + CUDAKernel as SymmetricTensorProduct, +) +from cuequivariance_torch.primitives.tensor_product import ( + FusedTensorProductOp3, + FusedTensorProductOp4, + TensorProductUniform3x1d, + TensorProductUniform4x1d, +) + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + + +def test_script_symmetric_contraction(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + ds = cue.descriptors.symmetric_contraction( + 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] + ).ds + + batch = 12 + x0 = torch.randn(3, ds[0].operands[0].size, device=device, dtype=torch.float32) + i0 = torch.zeros(batch, device=device, dtype=torch.int32) + x1 = torch.randn(batch, ds[0].operands[1].size, device=device, dtype=torch.float32) + + module = SymmetricTensorProduct(ds, device, torch.float32) + module = torch.jit.script(module) + + assert module(x0, i0, x1).shape == (batch, ds[0].operands[-1].size) + + +def test_script_fused_tp_3(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + d = ( + cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + + module = FusedTensorProductOp3(d, (0, 1), device, torch.float32) + module = torch.jit.script(module) + + assert module([x0, x1]).shape == (batch, d.operands[2].size) + + +def test_script_fused_tp_4(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + d = ( + cue.descriptors.fully_connected_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + .permute_operands([1, 2, 0, 3]) + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) + + module = FusedTensorProductOp4(d, (0, 1, 2), device, torch.float32) + module = torch.jit.script(module) + + assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) + + +def test_script_uniform_tp_3(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + d = ( + cue.descriptors.full_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + + module = TensorProductUniform3x1d(d, device, torch.float32) + module = torch.jit.script(module) + + assert module([x0, x1]).shape == (batch, d.operands[2].size) + + +def test_script_uniform_tp_4(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + d = ( + cue.descriptors.channelwise_tensor_product( + cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") + ) + .d.flatten_coefficient_modes() + .squeeze_modes("v") + ) + + batch = 12 + x0 = torch.randn(batch, d.operands[0].size, device=device, dtype=torch.float32) + x1 = torch.randn(batch, d.operands[1].size, device=device, dtype=torch.float32) + x2 = torch.randn(batch, d.operands[2].size, device=device, dtype=torch.float32) + + module = TensorProductUniform4x1d(d, device, torch.float32) + module = torch.jit.script(module) + + assert module([x0, x1, x2]).shape == (batch, d.operands[3].size) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 95dfc4d..7858576 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -20,6 +20,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def make_descriptors(): yield descriptors.symmetric_contraction( @@ -47,7 +49,7 @@ def make_descriptors(): (torch.float32, torch.float64, 1e-5), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings1 += [ (torch.float16, torch.float32, 1.0), (torch.float16, torch.float64, 0.1), @@ -58,14 +60,22 @@ def make_descriptors(): @pytest.mark.parametrize("ds", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings1) +@pytest.mark.parametrize("use_fallback", [False, True]) def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( - ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float + ds: list[stp.SegmentedTensorProduct], + dtype, + math_dtype, + tol: float, + use_fallback: bool, ): - device = torch.device("cuda:0") + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=math_dtype, device=device, optimize_fallback=False + ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) + if use_fallback is False: + m = torch.jit.script(m) x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True) i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device) @@ -75,11 +85,15 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( x0_ = x0.clone().to(torch.float64) x1_ = x1.clone().to(torch.float64) - out1 = m(x0, i0, x1, use_fallback=False) + out1 = m(x0, i0, x1) m = cuet.IWeightedSymmetricTensorProduct( - ds, math_dtype=torch.float64, device=device, optimize_fallback=True + ds, + math_dtype=torch.float64, + device=device, + use_fallback=True, + optimize_fallback=True, ) - out2 = m(x0_, i0, x1_, use_fallback=True) + out2 = m(x0_, i0, x1_) assert out1.dtype == dtype @@ -104,7 +118,7 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( (torch.float32, torch.float64), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings2 += [ (torch.float16, torch.float32), (torch.bfloat16, torch.float32), @@ -112,28 +126,29 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( @pytest.mark.parametrize("dtype, math_dtype", settings2) -def test_math_dtype( - dtype: torch.dtype, - math_dtype: torch.dtype, -): - device = torch.device("cuda:0") +@pytest.mark.parametrize("use_fallback", [False, True]) +def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: bool): + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") ds = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ).ds - m = cuet.IWeightedSymmetricTensorProduct(ds, math_dtype=math_dtype, device=device) + m = cuet.IWeightedSymmetricTensorProduct( + ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback + ) x0 = torch.randn((20, m.x0_size), dtype=dtype, device=device) i0 = torch.randint(0, m.x0_size, (1000,), dtype=torch.int32, device=device) x1 = torch.randn((i0.size(0), m.x1_size), dtype=dtype, device=device) - out1 = m(x0, i0, x1, use_fallback=False) + out1 = m(x0, i0, x1) # .to should have no effect for param in m.parameters(): assert False # no parameters m = m.to(torch.float64) - out2 = m(x0, i0, x1, use_fallback=False) + out2 = m(x0, i0, x1) assert out1.dtype == dtype assert out2.dtype == dtype diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 53e8bfc..d8c26ef 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -22,6 +22,8 @@ import cuequivariance_torch as cuet from cuequivariance import descriptors +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + def make_descriptors(): yield descriptors.fully_connected_tensor_product( @@ -80,7 +82,7 @@ def make_descriptors(): (torch.float64, torch.float64, 1e-12), ] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: settings += [ (torch.float16, torch.float32, 1.0), (torch.bfloat16, torch.float32, 1.0), @@ -89,13 +91,16 @@ def make_descriptors(): @pytest.mark.parametrize("d", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings) +@pytest.mark.parametrize("use_fallback", [False, True]) def test_primitive_tensor_product_cuda_vs_fx( d: stp.SegmentedTensorProduct, dtype: torch.dtype, math_dtype: torch.dtype, tol: float, + use_fallback: bool, ): - device = torch.device("cuda:0") + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") for batches in itertools.product([(16,), (), (4, 1)], repeat=d.num_operands - 1): inputs = [ @@ -109,14 +114,22 @@ def test_primitive_tensor_product_cuda_vs_fx( ] m = cuet.TensorProduct( - d, device=device, math_dtype=math_dtype, optimize_fallback=False + d, device=device, math_dtype=math_dtype, use_fallback=use_fallback ) - out1 = m(inputs, use_fallback=False) + if not use_fallback: + m = torch.jit.script(m) + + out1 = m(inputs) + m = cuet.TensorProduct( - d, device=device, math_dtype=torch.float64, optimize_fallback=False + d, + device=device, + math_dtype=torch.float64, + use_fallback=True, + optimize_fallback=False, ) inputs_ = [inp.clone().to(torch.float64) for inp in inputs] - out2 = m(inputs_, use_fallback=True) + out2 = m(inputs_) assert out1.shape[:-1] == torch.broadcast_shapes(*batches) assert out1.dtype == dtype diff --git a/cuequivariance_torch/tests/primitives/transpose_test.py b/cuequivariance_torch/tests/primitives/transpose_test.py index cd39cfc..31eb271 100644 --- a/cuequivariance_torch/tests/primitives/transpose_test.py +++ b/cuequivariance_torch/tests/primitives/transpose_test.py @@ -12,13 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch + import cuequivariance_torch as cuet -import pytest +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") dtypes = [torch.float32, torch.float64] -if torch.cuda.get_device_capability()[0] >= 8: +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: dtypes += [torch.float16, torch.bfloat16] @@ -33,14 +35,16 @@ def test_transpose(use_fallback: bool, dtype: torch.dtype): 10 11 10 12 12 13 11 13 """ + if use_fallback is False and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") x = torch.tensor( - [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10, 11, 12, 13]], dtype=dtype - ).cuda() + [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 10, 11, 12, 13]], dtype=dtype, device=device + ) segments = [(2, 3), (2, 2)] xt = torch.tensor( - [[1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 10, 12, 11, 13]], dtype=dtype - ).cuda() + [[1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 10, 12, 11, 13]], dtype=dtype, device=device + ) - m = cuet.TransposeSegments(segments).cuda() - torch.testing.assert_close(m(x, use_fallback=use_fallback), xt) + m = cuet.TransposeSegments(segments, device, use_fallback=use_fallback) + torch.testing.assert_close(m(x), xt)