Add support for zero batch (#27)
* test and quick fix for zero batch

* trigger uniform 1d in test

* satisfy linter

Signed-off-by: Mario Geiger <mgeiger@nvidia.com>

* from typing import

* determine math_dtype earlier

* warning with pip commands

* remove unused argument

* changelog

* add Fixed subtitle

---------

Signed-off-by: Mario Geiger <mgeiger@nvidia.com>
mariogeiger authored Nov 21, 2024
1 parent fef9ff1 commit 04595a8
Showing 8 changed files with 38 additions and 34 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,9 @@
+## Latest Changes
+
+### Fixed
+
+- Add support for empty batch dimension in `cuequivariance-torch`.
+
## 0.1.0 (2024-11-18)

- Beta version of cuEquivariance released.
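As a rough illustration of this changelog entry, here is a minimal sketch modeled on the updated `tp_channel_wise_test.py` further down; the exact constructor arguments are assumptions taken from that test and may need adjustment:

```python
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

irreps = cue.Irreps("O3", "32x0e + 32x1o")  # illustrative irreps
m = cuet.ChannelWiseTensorProduct(
    irreps, irreps, irreps, layout=cue.mul_ir, device="cuda", dtype=torch.float64
)

# Inputs with a zero-sized batch dimension are now accepted instead of raising;
# the output simply has batch size 0 as well.
x1 = torch.randn(0, irreps.dim, dtype=torch.float64, device="cuda")
x2 = torch.randn(0, irreps.dim, dtype=torch.float64, device="cuda")
out = m(x1, x2)
assert out.shape[0] == 0
```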
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
-from typing import *
+from typing import Optional

import numpy as np

@@ -164,7 +164,7 @@ def U_matrix_real(
assert isinstance(ir_out, cue.Irrep)

if correlation == 4:
-filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)])
+filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)])  # noqa E741
else:
filter_ir_mid = None

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import *
+from typing import Optional, Union

import torch

@@ -70,9 +70,6 @@ def __init__(
optimize_fallback: Optional[bool] = None,
):
super().__init__()
-cue.descriptors.fully_connected_tensor_product(
-    cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
-)
if not isinstance(layout_in, tuple):
layout_in = (layout_in,) * e.num_inputs
if len(layout_in) != e.num_inputs:
@@ -14,15 +14,13 @@
# limitations under the License.
import logging
import math
-import warnings
-from typing import *
+from typing import Optional

import torch
import torch.fx

-import cuequivariance.segmented_tensor_product as stp
import cuequivariance_torch as cuet
+from cuequivariance import segmented_tensor_product as stp

logger = logging.getLogger(__name__)

@@ -341,7 +339,7 @@ def forward(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
-out = out.reshape(out.shape[0], -1)
+out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out
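The reshape change above is the core of the zero-batch fix for this module: `reshape(out.shape[0], -1)` cannot infer `-1` once the tensor has zero elements, so the flattened width is now computed explicitly. A standalone illustration (the shape used here is made up):

```python
import torch

out = torch.empty(0, 3, 4)  # zero-sized batch dimension, as exercised by the new tests
# out.reshape(out.shape[0], -1) raises a RuntimeError: with 0 elements the
# inferred dimension -1 is ambiguous.
flat = out.reshape(out.shape[0], out.shape[1] * out.shape[2])
print(flat.shape)  # torch.Size([0, 12])
```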


@@ -15,7 +15,7 @@
import logging
import math
import warnings
-from typing import *
+from typing import Optional, OrderedDict, Tuple

import torch
import torch.fx
@@ -47,13 +47,22 @@ def __init__(
super().__init__()
self.descriptor = descriptor

+if math_dtype is None:
+    math_dtype = torch.get_default_dtype()
+
try:
self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype)
except NotImplementedError as e:
logger.info(f"CUDA implementation not available: {e}")
self.f_cuda = None
except ImportError as e:
logger.warning(f"CUDA implementation not available: {e}")
+logger.warning(
+    "Did you forget to install the CUDA version of cuequivariance-ops-torch?\n"
+    "Install it with one of the following commands:\n"
+    "pip install cuequivariance-ops-torch-cu11\n"
+    "pip install cuequivariance-ops-torch-cu12"
+)
self.f_cuda = None

self.f_fx = _tensor_product_fx(
@@ -88,6 +97,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
Raises:
RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
"""
+if any(x.numel() == 0 for x in args):
+    use_fallback = True  # Empty tensors are not supported by the CUDA kernel
+
if (
args
and args[0].device.type == "cuda"
@@ -113,18 +125,14 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
def _tensor_product_fx(
descriptor: stp.SegmentedTensorProduct,
device: Optional[torch.device],
-math_dtype: Optional[torch.dtype],
+math_dtype: torch.dtype,
optimize_einsums: bool,
) -> torch.nn.Module:
"""
batch support of this function:
- at least one input operand should have a batch dimension (ndim=2)
- the output operand will have a batch dimension (ndim=2)
"""

-if math_dtype is None:
-    math_dtype = torch.get_default_dtype()

descriptor = descriptor.remove_zero_paths()
descriptor = descriptor.remove_empty_segments()

@@ -285,7 +293,7 @@ def forward(self, *args):
(math.prod(shape), arg.shape[-1])
)
if math.prod(arg.shape[:-1]) > 1
-else arg.reshape((1, arg.shape[-1]))
+else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1]))
)
for arg in args
]
@@ -310,7 +318,7 @@ def _sum(tensors, *, shape=None, like=None):
def _tensor_product_cuda(
descriptor: stp.SegmentedTensorProduct,
device: Optional[torch.device],
-math_dtype: Optional[torch.dtype],
+math_dtype: torch.dtype,
) -> torch.nn.Module:
logger.debug(f"Starting search for a cuda kernel for {descriptor}")

@@ -323,9 +331,6 @@
f" Got {descriptor.subscripts}."
)

-if math_dtype is None:
-    math_dtype = torch.get_default_dtype()

if not torch.cuda.is_available():
raise NotImplementedError("CUDA is not available.")

@@ -438,12 +443,10 @@ def forward(
self,
x0: torch.Tensor,
x1: torch.Tensor,
-b2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x0, x1 = self._perm(x0, x1)
assert x0.ndim >= 1, x0.ndim
assert x1.ndim >= 1, x1.ndim
-assert b2 is None

shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1])
x0 = _reshape(x0, shape)
@@ -499,13 +502,11 @@ def forward(
x0: torch.Tensor,
x1: torch.Tensor,
x2: torch.Tensor,
-b3: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x0, x1, x2 = self._perm(x0, x1, x2)
assert x0.ndim >= 1, x0.ndim
assert x1.ndim >= 1, x1.ndim
assert x2.ndim >= 1, x2.ndim
-assert b3 is None

shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1])
x0 = _reshape(x0, shape)
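Two of the changes above follow the same pattern: an empty input makes a previously hard-coded shape invalid. `forward` now forces the fallback path whenever an input has `numel() == 0`, and the fallback's reshape uses the actual product of the leading dimensions instead of a hard-coded `1`. A minimal sketch of the reshape issue (shapes are made up):

```python
import math

import torch

arg = torch.randn(0, 8)           # empty batch, feature dimension 8
lead = math.prod(arg.shape[:-1])  # 0, so the old "> 1" test fell through to the else branch
# arg.reshape((1, arg.shape[-1])) would raise: 0 elements cannot fill a (1, 8) tensor.
flat = arg.reshape((lead, arg.shape[-1]))
print(flat.shape)                 # torch.Size([0, 8])
```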
@@ -30,7 +30,8 @@
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("original_mace", [True, False])
-def test_symmetric_contraction(dtype, layout, original_mace):
+@pytest.mark.parametrize("batch", [0, 32])
+def test_symmetric_contraction(dtype, layout, original_mace, batch):
mul = 64
irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e")
irreps_out = mul * cue.Irreps("O3", "0e + 1o")
@@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace):
original_mace=original_mace,
)

-Z = 32
-x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda()
-indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda()
+x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda()
+indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda()

out = m(x, indices)
-assert out.shape == (Z, irreps_out.dim)
+assert out.shape == (batch, irreps_out.dim)


def from64(shape: tuple[int, ...], data: str) -> torch.Tensor:
8 changes: 5 additions & 3 deletions cuequivariance_torch/tests/operations/tp_channel_wise_test.py
@@ -20,7 +20,7 @@
from cuequivariance import descriptors

list_of_irreps = [
cue.Irreps("O3", "4x0e + 4x1o"),
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"),
cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"),
]
@@ -31,12 +31,14 @@
@pytest.mark.parametrize("irreps3", list_of_irreps)
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("use_fallback", [False, True])
@pytest.mark.parametrize("batch", [0, 32])
def test_channel_wise(
irreps1: cue.Irreps,
irreps2: cue.Irreps,
irreps3: cue.Irreps,
layout: cue.IrrepsLayout,
use_fallback: bool,
+batch: int,
):
m = cuet.ChannelWiseTensorProduct(
irreps1,
@@ -49,8 +51,8 @@ def test_channel_wise(
dtype=torch.float64,
)

-x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda()
-x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda()
+x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda()
+x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda()

out1 = m(x1, x2, use_fallback=use_fallback)

2 changes: 1 addition & 1 deletion docs/changelog.md
@@ -1,3 +1,3 @@
-# Change Log
+# Changelog

```{include} ../CHANGELOG.md
