Skip to content

Commit

Permalink
Avoid calling irreps.dim and logger in forward (fix for #32) (#35)
Browse files Browse the repository at this point in the history
* avoid calling irreps.dim and logger in forward

* add tests

* optimize_fallback=True
  • Loading branch information
mariogeiger authored Dec 2, 2024
1 parent f4b52be commit 815289a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def __init__(
)
self.symm_tp = None

self.operands_dims = [op.irreps.dim for op in e.operands]

def extra_repr(self) -> str:
return str(self.etp)

Expand All @@ -146,8 +148,8 @@ def forward(
inputs: list[torch.Tensor] = list(inputs)

assert len(inputs) == len(self.etp.inputs)
for a, b in zip(inputs, self.etp.inputs):
assert a.shape[-1] == b.irreps.dim
for a, dim in zip(inputs, self.operands_dims):
assert a.shape[-1] == dim

# Transpose inputs
inputs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,6 @@ def forward(self, *args):
for arg in args
]

logger.debug(
f"Calling torch.fx tensor product: {self.descriptor}, input shapes: {', '.join(str(arg.shape) for arg in args)}"
)
out = self.module(*args)

return out.reshape(shape + (out.shape[-1],))
Expand Down
15 changes: 15 additions & 0 deletions cuequivariance_torch/tests/operations/tp_fully_connected_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,18 @@ def test_fully_connected(
).to(out1.dtype)

torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)


def test_compile():
    """Smoke test: FullyConnectedTensorProduct must run under torch.compile(fullgraph=True).

    Guards against graph breaks in forward (e.g. calls to irreps.dim or the
    logger), which previously broke full-graph compilation.
    """
    module = cuet.FullyConnectedTensorProduct(
        irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"),
        irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"),
        irreps_out=cue.Irreps("O3", "32x0e + 32x1o"),
        layout=cue.mul_ir,
        optimize_fallback=False,
    )

    compiled = torch.compile(module, fullgraph=True)

    # Batch of 100 random samples with the flattened irreps dimensions.
    x1 = torch.randn(100, module.irreps_in1.dim)
    x2 = torch.randn(100, module.irreps_in2.dim)
    compiled(x1, x2)
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,14 @@ def test_precision_cuda_vs_fx(
y1 = m(*inputs, use_fallback=True).to(dtype)

torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol)


def test_compile():
    """Smoke test: EquivariantTensorProduct must run under torch.compile(fullgraph=True).

    Guards against graph breaks in forward (e.g. calls to irreps.dim or the
    logger), which previously broke full-graph compilation.
    """
    descriptor = cue.descriptors.symmetric_contraction(
        cue.Irreps("O3", "32x0e + 32x1o"),
        cue.Irreps("O3", "32x0e + 32x1o"),
        [1, 2, 3],
    )
    module = cuet.EquivariantTensorProduct(
        descriptor, layout=cue.mul_ir, optimize_fallback=False
    )

    compiled = torch.compile(module, fullgraph=True)

    # Batch of 100 random samples matching each operand's flattened dimension.
    x1 = torch.randn(100, descriptor.inputs[0].irreps.dim)
    x2 = torch.randn(100, descriptor.inputs[1].irreps.dim)
    compiled(x1, x2)

0 comments on commit 815289a

Please sign in to comment.