Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add support for zero batch #27

Merged
merged 9 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## Latest Changes

### Fixed

- Add support for empty batch dimension in `cuequivariance-torch`.
mariogeiger marked this conversation as resolved.
Show resolved Hide resolved

## 0.1.0 (2024-11-18)

- Beta version of cuEquivariance released.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cache
from typing import *
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -164,7 +164,7 @@ def U_matrix_real(
assert isinstance(ir_out, cue.Irrep)

if correlation == 4:
filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)])
filter_ir_mid = frozenset([G(l, (-1) ** l) for l in range(11 + 1)]) # noqa E741
else:
filter_ir_mid = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import *
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -70,9 +70,6 @@ def __init__(
optimize_fallback: Optional[bool] = None,
):
super().__init__()
cue.descriptors.fully_connected_tensor_product(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is deleting this safe here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This like has probably been copy-pasted here by accident. it has no effect and has nothing to do here

cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
)
if not isinstance(layout_in, tuple):
layout_in = (layout_in,) * e.num_inputs
if len(layout_in) != e.num_inputs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
# limitations under the License.
import logging
import math
import warnings
from typing import *
from typing import Optional

import torch
import torch.fx

import cuequivariance.segmented_tensor_product as stp
import cuequivariance_torch as cuet
from cuequivariance import segmented_tensor_product as stp

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -341,7 +339,7 @@ def forward(
f"Calling SymmetricTensorContraction: {self.descriptors}, input shapes: {x0.shape}, {i0.shape}, {x1.shape}"
)
out = self.f(x1, x0, i0)
out = out.reshape(out.shape[0], -1)
out = out.reshape(out.shape[0], out.shape[1] * self.u)
return out


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
import math
import warnings
from typing import *
from typing import Optional, OrderedDict, Tuple

import torch
import torch.fx
Expand Down Expand Up @@ -47,13 +47,22 @@ def __init__(
super().__init__()
self.descriptor = descriptor

if math_dtype is None:
math_dtype = torch.get_default_dtype()

try:
self.f_cuda = _tensor_product_cuda(descriptor, device, math_dtype)
except NotImplementedError as e:
logger.info(f"CUDA implementation not available: {e}")
self.f_cuda = None
except ImportError as e:
logger.warning(f"CUDA implementation not available: {e}")
logger.warning(
"Did you forget to install the CUDA version of cuequivariance-ops-torch?\n"
"Install it with one of the following commands:\n"
"pip install cuequivariance-ops-torch-cu11\n"
"pip install cuequivariance-ops-torch-cu12"
)
self.f_cuda = None

self.f_fx = _tensor_product_fx(
Expand Down Expand Up @@ -88,6 +97,9 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
Raises:
RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
"""
if any(x.numel() == 0 for x in args):
use_fallback = True # Empty tensors are not supported by the CUDA kernel

if (
args
and args[0].device.type == "cuda"
Expand All @@ -113,18 +125,14 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
def _tensor_product_fx(
descriptor: stp.SegmentedTensorProduct,
device: Optional[torch.device],
math_dtype: Optional[torch.dtype],
math_dtype: torch.dtype,
optimize_einsums: bool,
) -> torch.nn.Module:
"""
batch support of this function:
- at least one input operand should have a batch dimension (ndim=2)
- the output operand will have a batch dimension (ndim=2)
"""

if math_dtype is None:
math_dtype = torch.get_default_dtype()

descriptor = descriptor.remove_zero_paths()
descriptor = descriptor.remove_empty_segments()

Expand Down Expand Up @@ -285,7 +293,7 @@ def forward(self, *args):
(math.prod(shape), arg.shape[-1])
)
if math.prod(arg.shape[:-1]) > 1
else arg.reshape((1, arg.shape[-1]))
else arg.reshape((math.prod(arg.shape[:-1]), arg.shape[-1]))
)
for arg in args
]
Expand All @@ -310,7 +318,7 @@ def _sum(tensors, *, shape=None, like=None):
def _tensor_product_cuda(
descriptor: stp.SegmentedTensorProduct,
device: Optional[torch.device],
math_dtype: Optional[torch.dtype],
math_dtype: torch.dtype,
) -> torch.nn.Module:
logger.debug(f"Starting search for a cuda kernel for {descriptor}")

Expand All @@ -323,9 +331,6 @@ def _tensor_product_cuda(
f" Got {descriptor.subscripts}."
)

if math_dtype is None:
math_dtype = torch.get_default_dtype()

if not torch.cuda.is_available():
raise NotImplementedError("CUDA is not available.")

Expand Down Expand Up @@ -438,12 +443,10 @@ def forward(
self,
x0: torch.Tensor,
x1: torch.Tensor,
b2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x0, x1 = self._perm(x0, x1)
assert x0.ndim >= 1, x0.ndim
assert x1.ndim >= 1, x1.ndim
assert b2 is None

shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1])
x0 = _reshape(x0, shape)
Expand Down Expand Up @@ -499,13 +502,11 @@ def forward(
x0: torch.Tensor,
x1: torch.Tensor,
x2: torch.Tensor,
b3: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x0, x1, x2 = self._perm(x0, x1, x2)
assert x0.ndim >= 1, x0.ndim
assert x1.ndim >= 1, x1.ndim
assert x2.ndim >= 1, x2.ndim
assert b3 is None

shape = torch.broadcast_shapes(x0.shape[:-1], x1.shape[:-1], x2.shape[:-1])
x0 = _reshape(x0, shape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("original_mace", [True, False])
def test_symmetric_contraction(dtype, layout, original_mace):
@pytest.mark.parametrize("batch", [0, 32])
def test_symmetric_contraction(dtype, layout, original_mace, batch):
mul = 64
irreps_in = mul * cue.Irreps("O3", "0e + 1o + 2e")
irreps_out = mul * cue.Irreps("O3", "0e + 1o")
Expand All @@ -48,12 +49,11 @@ def test_symmetric_contraction(dtype, layout, original_mace):
original_mace=original_mace,
)

Z = 32
x = torch.randn((Z, irreps_in.dim), dtype=dtype).cuda()
indices = torch.randint(0, 5, (Z,), dtype=torch.int32).cuda()
x = torch.randn((batch, irreps_in.dim), dtype=dtype).cuda()
indices = torch.randint(0, 5, (batch,), dtype=torch.int32).cuda()

out = m(x, indices)
assert out.shape == (Z, irreps_out.dim)
assert out.shape == (batch, irreps_out.dim)


def from64(shape: tuple[int, ...], data: str) -> torch.Tensor:
Expand Down
8 changes: 5 additions & 3 deletions cuequivariance_torch/tests/operations/tp_channel_wise_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cuequivariance import descriptors

list_of_irreps = [
cue.Irreps("O3", "4x0e + 4x1o"),
cue.Irreps("O3", "32x0e + 32x1o"),
cue.Irreps("O3", "2x1o + 5x0e + 2e + 1e + 1o"),
cue.Irreps("O3", "2e + 0x0e + 0o + 0x1e + 1e"),
]
Expand All @@ -31,12 +31,14 @@
@pytest.mark.parametrize("irreps3", list_of_irreps)
@pytest.mark.parametrize("layout", [cue.ir_mul, cue.mul_ir])
@pytest.mark.parametrize("use_fallback", [False, True])
@pytest.mark.parametrize("batch", [0, 32])
def test_channel_wise(
irreps1: cue.Irreps,
irreps2: cue.Irreps,
irreps3: cue.Irreps,
layout: cue.IrrepsLayout,
use_fallback: bool,
batch: int,
):
m = cuet.ChannelWiseTensorProduct(
irreps1,
Expand All @@ -49,8 +51,8 @@ def test_channel_wise(
dtype=torch.float64,
)

x1 = torch.randn(32, irreps1.dim, dtype=torch.float64).cuda()
x2 = torch.randn(32, irreps2.dim, dtype=torch.float64).cuda()
x1 = torch.randn(batch, irreps1.dim, dtype=torch.float64).cuda()
x2 = torch.randn(batch, irreps2.dim, dtype=torch.float64).cuda()

out1 = m(x1, x2, use_fallback=use_fallback)

Expand Down
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Change Log
# Changelog

```{include} ../CHANGELOG.md
Loading