diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
index 113fc39..fcc0604 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py
@@ -131,6 +131,8 @@ def __init__(
             )
             self.symm_tp = None
 
+        self.operands_dims = [op.irreps.dim for op in e.operands]
+
     def extra_repr(self) -> str:
         return str(self.etp)
 
@@ -146,8 +148,8 @@ def forward(
         inputs: list[torch.Tensor] = list(inputs)
         assert len(inputs) == len(self.etp.inputs)
 
-        for a, b in zip(inputs, self.etp.inputs):
-            assert a.shape[-1] == b.irreps.dim
+        for a, dim in zip(inputs, self.operands_dims):
+            assert a.shape[-1] == dim
 
         # Transpose inputs
         inputs = [
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
index 736c2aa..a133b44 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py
@@ -298,9 +298,6 @@ def forward(self, *args):
             for arg in args
         ]
 
-        logger.debug(
-            f"Calling torch.fx tensor product: {self.descriptor}, input shapes: {', '.join(str(arg.shape) for arg in args)}"
-        )
         out = self.module(*args)
         return out.reshape(shape + (out.shape[-1],))
 
diff --git a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
index d00e6ba..045c6ba 100644
--- a/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
+++ b/cuequivariance_torch/tests/operations/tp_fully_connected_test.py
@@ -66,3 +66,18 @@ def test_fully_connected(
     ).to(out1.dtype)
 
     torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
+
+
+def test_compile():
+    m = cuet.FullyConnectedTensorProduct(
+        irreps_in1=cue.Irreps("O3", "32x0e + 32x1o"),
+        irreps_in2=cue.Irreps("O3", "32x0e + 32x1o"),
+        irreps_out=cue.Irreps("O3", "32x0e + 32x1o"),
+        layout=cue.mul_ir,
+        optimize_fallback=False,
+    )
+
+    m_compile = torch.compile(m, fullgraph=True)
+    input1 = torch.randn(100, m.irreps_in1.dim)
+    input2 = torch.randn(100, m.irreps_in2.dim)
+    m_compile(input1, input2)
diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
index aa1b0dd..1bfd619 100644
--- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
@@ -143,3 +143,14 @@
     y1 = m(*inputs, use_fallback=True).to(dtype)
 
     torch.testing.assert_close(y0, y1, atol=atol, rtol=rtol)
+
+
+def test_compile():
+    e = cue.descriptors.symmetric_contraction(
+        cue.Irreps("O3", "32x0e + 32x1o"), cue.Irreps("O3", "32x0e + 32x1o"), [1, 2, 3]
+    )
+    m = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, optimize_fallback=False)
+    m_compile = torch.compile(m, fullgraph=True)
+    input1 = torch.randn(100, e.inputs[0].irreps.dim)
+    input2 = torch.randn(100, e.inputs[1].irreps.dim)
+    m_compile(input1, input2)