List as inputs (#28)
* test and quick fix for zero batch

* trigger uniform 1d in test

* satisfy linter

Signed-off-by: Mario Geiger <mgeiger@nvidia.com>

* from typing import

* determine math_dtype earlier

* warning with pip commands

* remove unused argument

* changelog

* list of inputs

* add Fixed subtitle

* changelog

* fix

---------

Signed-off-by: Mario Geiger <mgeiger@nvidia.com>
mariogeiger authored Dec 3, 2024
1 parent 815289a commit fc43247
Showing 16 changed files with 45 additions and 42 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,9 @@
## Latest Changes

### Changed

- `cuequivariance_torch.TensorProduct` and `cuequivariance_torch.EquivariantTensorProduct` now require lists of `torch.Tensor` as input.

### Fixed

- Add support for empty batch dimension in `cuequivariance-torch`.
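A minimal sketch of the new calling convention described in the Changed entry above (illustration only, not part of this commit's diff). The descriptor construction and the `cue`/`cuet` import aliases are assumptions based on the library's docstrings; the point is only the list-wrapping of the inputs.

import torch
import cuequivariance as cue            # assumed import alias
import cuequivariance_torch as cuet     # assumed import alias

# Assumed descriptor with three inputs (weights, x1, x2); any descriptor works.
e = cue.descriptors.fully_connected_tensor_product(
    cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1"), cue.Irreps("SO3", "2x1")
)
tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)

w = torch.ones(17, e.inputs[0].irreps.dim)
x1 = torch.ones(17, e.inputs[1].irreps.dim)
x2 = torch.ones(17, e.inputs[2].irreps.dim)

# out = tp(w, x1, x2)        # before this change: separate positional tensors
out = tp([w, x1, x2])         # after this change: a single list of tensors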
@@ -126,4 +126,4 @@ def forward(
if not self.shared_weights and weight.ndim != 2:
raise ValueError("Weights should be 2D tensor")

return self.f(weight, x, use_fallback=use_fallback)
return self.f([weight, x], use_fallback=use_fallback)
@@ -96,10 +96,7 @@ def forward(
encodings_alpha = encode_rotation_angle(alpha, self.lmax)

return self.f(
encodings_gamma,
encodings_beta,
encodings_alpha,
x,
[encodings_gamma, encodings_beta, encodings_alpha, x],
use_fallback=use_fallback,
)

@@ -194,4 +191,4 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the inversion layer."""
return self.f(x)
return self.f([x])
@@ -55,6 +55,6 @@ def spherical_harmonics(
math_dtype=x.dtype,
optimize_fallback=optimize_fallback,
)
y = m(x)
y = m([x])
y = y.reshape(vectors.shape[:-1] + (y.shape[-1],))
return y
@@ -188,4 +188,4 @@ def forward(
weight = self.weight
weight = weight.flatten(1)

return self.f(weight, x, indices=indices, use_fallback=use_fallback)
return self.f([weight, x], indices=indices, use_fallback=use_fallback)
@@ -147,4 +147,4 @@ def forward(
if not self.shared_weights and weight.ndim != 2:
raise ValueError("Weights should be 2D tensor")

return self.f(weight, x1, x2, use_fallback=use_fallback)
return self.f([weight, x1, x2], use_fallback=use_fallback)
@@ -148,4 +148,4 @@ def forward(
if not self.shared_weights and weight.ndim != 2:
raise ValueError("Weights should be 2D tensor")

return self.f(weight, x1, x2, use_fallback=use_fallback)
return self.f([weight, x1, x2], use_fallback=use_fallback)
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import List, Optional, Union

import torch

@@ -41,7 +41,7 @@ class EquivariantTensorProduct(torch.nn.Module):
>>> x1 = torch.ones(17, e.inputs[1].irreps.dim)
>>> x2 = torch.ones(17, e.inputs[2].irreps.dim)
>>> tp = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)
>>> tp(w, x1, x2)
>>> tp([w, x1, x2])
tensor([[0., 0., 0., 0., 0., 0.],
...
[0., 0., 0., 0., 0., 0.]])
@@ -50,7 +50,7 @@ class EquivariantTensorProduct(torch.nn.Module):
>>> w = torch.ones(3, e.inputs[0].irreps.dim)
>>> indices = torch.randint(3, (17,))
>>> tp(w, x1, x2, indices=indices)
>>> tp([w, x1, x2], indices=indices)
tensor([[0., 0., 0., 0., 0., 0.],
...
[0., 0., 0., 0., 0., 0.]])
@@ -138,14 +138,14 @@ def extra_repr(self) -> str:

def forward(
self,
*inputs: torch.Tensor,
inputs: List[torch.Tensor],
indices: Optional[torch.Tensor] = None,
use_fallback: Optional[bool] = None,
) -> torch.Tensor:
"""
If ``indices`` is not None, the first input is indexed by ``indices``.
"""
inputs: list[torch.Tensor] = list(inputs)
inputs: List[torch.Tensor] = list(inputs)

assert len(inputs) == len(self.etp.inputs)
for a, dim in zip(inputs, self.operands_dims):
@@ -164,7 +164,7 @@ def forward(
# TODO: at some point we will have kernel for this
assert len(inputs) >= 1
inputs[0] = inputs[0][indices]
output = self.tp(*inputs, use_fallback=use_fallback)
output = self.tp(inputs, use_fallback=use_fallback)

if self.symm_tp is not None:
if len(inputs) == 1:
@@ -174,6 +174,10 @@ def forward(
if len(inputs) == 2:
[x0, x1] = inputs
if indices is None:
torch._assert(
x0.ndim == 2,
f"Expected x0 to have shape (batch, dim), got {x0.shape}",
)
if x0.shape[0] == 1:
indices = torch.zeros(
(x1.shape[0],), dtype=torch.int32, device=x1.device
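An illustrative note on the `indices` path touched above (a sketch, not part of the diff): per the forward docstring, the first input is indexed by `indices` before the product is evaluated, so the indexed call is expected to agree with gathering the weights yourself.

# Reusing w, x1, x2, tp and indices from the docstring example above.
out_a = tp([w, x1, x2], indices=indices)   # weights gathered internally: w[indices]
out_b = tp([w[indices], x1, x2])           # expected to agree with out_a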
@@ -109,7 +109,7 @@ def forward(
use_fallback=use_fallback,
)
if self.f0 is not None:
out += self.f0()
out += self.f0([])
return out


@@ -201,7 +201,7 @@ def forward(

torch._assert(
x0.ndim == 2,
f"Expected 2 dims (i0.max() + 1, x0_size), got {x0.ndim}",
f"Expected 2 dims (i0.max() + 1, x0_size), got shape {x0.shape}",
)
shape = torch.broadcast_shapes(i0.shape, x1.shape[:-1])
i0 = i0.expand(shape).reshape((math.prod(shape),))
@@ -368,6 +368,6 @@ def forward(
self, x0: torch.Tensor, i0: torch.Tensor, x1: torch.Tensor
) -> torch.Tensor:
return sum(
f(x0[i0], *[x1] * (f.descriptor.num_operands - 2), use_fallback=True)
f([x0[i0]] + [x1] * (f.descriptor.num_operands - 2), use_fallback=True)
for f in self.fs
)
@@ -15,7 +15,7 @@
import logging
import math
import warnings
from typing import Optional, OrderedDict, Tuple
from typing import List, Optional, OrderedDict, Tuple

import torch
import torch.fx
@@ -76,12 +76,12 @@ def __repr__(self):
)
return f"TensorProduct({self.descriptor} {has_cuda_kernel})"

def forward(self, *args, use_fallback: Optional[bool] = None):
def forward(self, inputs: List[torch.Tensor], use_fallback: Optional[bool] = None):
r"""
Perform the tensor product based on the specified descriptor.
Args:
args (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one.
inputs (list of torch.Tensor): The input tensors. The number of input tensors should match the number of operands in the descriptor minus one.
Each input tensor should have a shape of ((batch,) operand_size), where `operand_size` corresponds to the size
of each operand as defined in the tensor product descriptor.
use_fallback (bool, optional): Determines the computation method. If `None` (default), a CUDA kernel will be used if available and the input
@@ -97,16 +97,16 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
Raises:
RuntimeError: If `use_fallback` is `False` and either no CUDA kernel is available or the input tensor is not on CUDA.
"""
if any(x.numel() == 0 for x in args):
if any(x.numel() == 0 for x in inputs):
use_fallback = True # Empty tensors are not supported by the CUDA kernel

if (
args
and args[0].device.type == "cuda"
inputs
and inputs[0].device.type == "cuda"
and self.f_cuda is not None
and (use_fallback is not True)
):
return self.f_cuda(*args)
return self.f_cuda(*inputs)

if use_fallback is False:
if self.f_cuda is not None:
@@ -119,7 +119,7 @@ def forward(self, *args, use_fallback: Optional[bool] = None):
"The fallback method is used but it has not been optimized. "
"Consider setting optimize_fallback=True when creating the TensorProduct module."
)
return self.f_fx(*args)
return self.f_fx(inputs)


def _tensor_product_fx(
@@ -278,7 +278,7 @@ def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProdu
self.module = module
self.descriptor = descriptor

def forward(self, *args):
def forward(self, args):
for oid, arg in enumerate(args):
torch._assert(
arg.shape[-1] == self.descriptor.operands[oid].size,
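A short usage sketch of the `use_fallback` semantics documented in the forward docstring above (illustration only; `d`, `w`, `x1`, `x2` are assumed to be a segmented tensor product descriptor and matching input tensors, much like in the tests below).

m = cuet.TensorProduct(d, math_dtype=torch.float64)   # d: assumed descriptor

out = m([w, x1, x2])                     # default: CUDA kernel when available, fallback otherwise
out = m([w, x1, x2], use_fallback=True)  # force the FX fallback (warns if not optimized)
# m([w, x1, x2], use_fallback=False)     # raises RuntimeError if the CUDA kernel cannot be used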
@@ -62,7 +62,7 @@ def test_channel_wise(
if layout == cue.mul_ir:
d = d.add_or_transpose_modes("u,ui,j,uk+ijk")
mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda()
out2 = mfx(m.weight, x1, x2, use_fallback=True)
out2 = mfx([m.weight, x1, x2], use_fallback=True)

torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)

@@ -59,9 +59,7 @@ def test_fully_connected(
d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk")
mfx = cuet.TensorProduct(d, math_dtype=torch.float64).cuda()
out2 = mfx(
m.weight.to(torch.float64),
x1.to(torch.float64),
x2.to(torch.float64),
[m.weight.to(torch.float64), x1.to(torch.float64), x2.to(torch.float64)],
use_fallback=True,
).to(out1.dtype)

@@ -84,11 +84,11 @@ def test_performance_cuda_vs_fx(
]

for _ in range(10):
m(*inputs, use_fallback=False)
m(*inputs, use_fallback=True)
m(inputs, use_fallback=False)
m(inputs, use_fallback=True)

def f(ufb: bool):
m(*inputs, use_fallback=ufb)
m(inputs, use_fallback=ufb)
torch.cuda.synchronize()

t0 = timeit.Timer(lambda: f(False)).timeit(number=10)
@@ -130,7 +130,7 @@ def test_precision_cuda_vs_fx(
device=device,
math_dtype=math_dtype,
)
y0 = m(*inputs, use_fallback=False)
y0 = m(inputs, use_fallback=False)

m = cuet.EquivariantTensorProduct(
e,
@@ -140,7 +140,7 @@
optimize_fallback=True,
)
inputs = map(lambda x: x.to(torch.float64), inputs)
y1 = m(*inputs, use_fallback=True).to(dtype)
y1 = m(inputs, use_fallback=True).to(dtype)

torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol)

@@ -153,4 +153,4 @@ def test_compile():
m_compile = torch.compile(m, fullgraph=True)
input1 = torch.randn(100, e.inputs[0].irreps.dim)
input2 = torch.randn(100, e.inputs[1].irreps.dim)
m_compile(input1, input2)
m_compile([input1, input2])
4 changes: 2 additions & 2 deletions cuequivariance_torch/tests/primitives/tensor_product_test.py
@@ -111,12 +111,12 @@ def test_primitive_tensor_product_cuda_vs_fx(
m = cuet.TensorProduct(
d, device=device, math_dtype=math_dtype, optimize_fallback=False
)
out1 = m(*inputs, use_fallback=False)
out1 = m(inputs, use_fallback=False)
m = cuet.TensorProduct(
d, device=device, math_dtype=torch.float64, optimize_fallback=False
)
inputs_ = [inp.clone().to(torch.float64) for inp in inputs]
out2 = m(*inputs_, use_fallback=True)
out2 = m(inputs_, use_fallback=True)

assert out1.shape[:-1] == torch.broadcast_shapes(*batches)
assert out1.dtype == dtype
2 changes: 1 addition & 1 deletion docs/tutorials/etp.rst
@@ -94,6 +94,6 @@ We can execute an :class:`cuequivariance.EquivariantTensorProduct` with PyTorch.
w = torch.randn(e.inputs[0].irreps.dim)
x = torch.randn(e.inputs[1].irreps.dim)

module(w, x)
module([w, x])

Note that you have to specify the layout. If the layout specified is different from the one in the descriptor, the module will transpose the inputs/output to match the layout.
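A small sketch of the layout remark above (not part of the diff; it reuses `e`, `w`, `x` from this tutorial and the two layouts seen elsewhere in this commit).

module_ir_mul = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul)
module_mul_ir = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir)

# Whichever layout differs from the descriptor's is handled by transposing
# the inputs/output inside the module, as noted above.
module_ir_mul([w, x])
module_mul_ir([w, x])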
2 changes: 1 addition & 1 deletion docs/tutorials/stp.rst
@@ -112,7 +112,7 @@ Now we can execute the linear layer with random input and weight tensors.
w = torch.randn(d.operands[0].size)
x1 = torch.randn(3000, irreps1.dim)

x2 = linear_torch(w, x1)
x2 = linear_torch([w, x1])

assert x2.shape == (3000, irreps2.dim)
