From 87738da973a1a3fb23d9fef67af7378071ddb167 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Wed, 8 Jan 2025 13:53:58 +0100
Subject: [PATCH] Add batch size parameterization to tensor product tests and
 implement empty tensor test

---
 .../equivariant_tensor_product_test.py        | 13 ++++++-------
 .../symmetric_tensor_product_test.py          |  7 ++++---
 .../tests/primitives/tensor_product_test.py   | 20 ++++++++------------
 .../tests/primitives/transpose_test.py        | 14 ++++++++++++++
 4 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
index dd4389ba..6ba62ab9 100644
--- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py
@@ -130,20 +130,23 @@ def f1():
 ]
 
 
-@pytest.mark.parametrize("e", make_descriptors())
+@pytest.mark.parametrize("batch_size", [0, 5])
 @pytest.mark.parametrize("dtype, math_dtype, atol, rtol", settings2)
+@pytest.mark.parametrize("e", make_descriptors())
 def test_precision_cuda_vs_fx(
     e: cue.EquivariantTensorProduct,
     dtype: torch.dtype,
     math_dtype: torch.dtype,
     atol: float,
     rtol: float,
+    batch_size: int,
 ):
     if not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     inputs = [
-        torch.randn((1024, inp.dim), device=device, dtype=dtype) for inp in e.inputs
+        torch.randn((batch_size, inp.dim), device=device, dtype=dtype)
+        for inp in e.inputs
     ]
     m = cuet.EquivariantTensorProduct(
         e, layout=cue.ir_mul, device=device, math_dtype=math_dtype, use_fallback=False
@@ -151,11 +154,7 @@ def test_precision_cuda_vs_fx(
     y0 = m(inputs)
 
     m = cuet.EquivariantTensorProduct(
-        e,
-        layout=cue.ir_mul,
-        device=device,
-        math_dtype=torch.float64,
-        use_fallback=True,
+        e, layout=cue.ir_mul, device=device, math_dtype=torch.float64, use_fallback=True
     )
     inputs = [x.to(torch.float64) for x in inputs]
     y1 = m(inputs).to(dtype)
diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py
index 9662e859..d5ae3010 100644
--- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py
@@ -55,10 +55,11 @@ def make_descriptors():
     ]
 
 
+@pytest.mark.parametrize("batch_size", [0, 3])
 @pytest.mark.parametrize("ds", make_descriptors())
 @pytest.mark.parametrize("dtype, math_dtype, tol", settings1)
 def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx(
-    ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float
+    ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float, batch_size: int
 ):
     use_fallback = not torch.cuda.is_available()
 
@@ -67,9 +68,9 @@ def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx(
     )
 
     x0 = torch.randn((2, m.x0_size), device=device, dtype=dtype, requires_grad=True)
-    i0 = torch.tensor([0, 1, 0], dtype=torch.int32, device=device)
+    i0 = torch.randint(0, x0.size(0), (batch_size,), dtype=torch.int32, device=device)
     x1 = torch.randn(
-        (i0.size(0), m.x1_size), device=device, dtype=dtype, requires_grad=True
+        (batch_size, m.x1_size), device=device, dtype=dtype, requires_grad=True
     )
     x0_ = x0.clone().to(torch.float64)
     x1_ = x1.clone().to(torch.float64)
diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py
index a54a5309..bfd24fa8 100644
--- a/cuequivariance_torch/tests/primitives/tensor_product_test.py
+++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py
@@ -90,22 +90,24 @@ def make_descriptors():
     ]
 
 
-@pytest.mark.parametrize("d", make_descriptors())
-@pytest.mark.parametrize("dtype, math_dtype, tol", settings)
+@pytest.mark.parametrize("batch_size", [0, 3])
 @pytest.mark.parametrize("use_fallback", [True, False])
+@pytest.mark.parametrize("dtype, math_dtype, tol", settings)
+@pytest.mark.parametrize("d", make_descriptors())
 def test_primitive_tensor_product_cuda_vs_fx(
     d: cue.SegmentedTensorProduct,
     dtype: torch.dtype,
     math_dtype: torch.dtype,
     tol: float,
     use_fallback: bool,
+    batch_size: int,
 ):
     if use_fallback is False and not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     inputs = [
         torch.randn(
-            (12, d.operands[i].size),
+            (batch_size, d.operands[i].size),
             device=device,
             dtype=dtype,
             requires_grad=True,
@@ -114,25 +116,19 @@ def test_primitive_tensor_product_cuda_vs_fx(
     ]
 
     m = cuet.TensorProduct(
-        d,
-        device=device,
-        math_dtype=math_dtype,
-        use_fallback=use_fallback,
+        d, device=device, math_dtype=math_dtype, use_fallback=use_fallback
     )
 
     out1 = m(inputs)
 
     m = cuet.TensorProduct(
-        d,
-        device=device,
-        math_dtype=torch.float64,
-        use_fallback=True,
+        d, device=device, math_dtype=torch.float64, use_fallback=True
     )
 
     inputs_ = [inp.to(torch.float64) for inp in inputs]
     out2 = m(inputs_)
 
-    assert out1.shape[:-1] == (12,)
+    assert out1.shape[:-1] == (batch_size,)
     assert out1.dtype == dtype
 
     torch.testing.assert_close(out1, out2.to(dtype), atol=tol, rtol=tol)
diff --git a/cuequivariance_torch/tests/primitives/transpose_test.py b/cuequivariance_torch/tests/primitives/transpose_test.py
index f1b32d70..e6254d99 100644
--- a/cuequivariance_torch/tests/primitives/transpose_test.py
+++ b/cuequivariance_torch/tests/primitives/transpose_test.py
@@ -53,6 +53,20 @@ def test_transpose(use_fallback: bool, dtype: torch.dtype):
     torch.testing.assert_close(m(x), xt)
 
 
+@pytest.mark.parametrize("use_fallback", [False, True])
+@pytest.mark.parametrize("dtype", dtypes)
+def test_transpose_empty_tensor(use_fallback: bool, dtype: torch.dtype):
+    if use_fallback is False and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    x = torch.zeros((0, 10), dtype=dtype, device=device)
+    segments = [(2, 3), (2, 2)]
+    xt = torch.zeros((0, 10), dtype=dtype, device=device)
+
+    m = cuet.TransposeSegments(segments, device, use_fallback=use_fallback)
+    torch.testing.assert_close(m(x), xt)
+
+
 export_modes = ["compile", "script", "jit"]
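
Note: the new tests exercise the batch_size=0 path, i.e. modules must accept
empty input tensors and return empty outputs of the correct shape. Below is a
minimal sketch of that pattern in plain PyTorch, independent of the
cuequivariance_torch API (the 10-dim layout split into (2, 3) and (2, 2)
segments mirrors the transpose test above; the manual reshape/transpose is an
illustrative stand-in, not the library's implementation of TransposeSegments):

import torch

# A batch of size 0 is a valid tensor with zero elements; shape-preserving
# ops should propagate the empty batch dimension instead of raising.
batch_size = 0
x = torch.randn((batch_size, 10))  # shape (0, 10)

# Transpose each flattened segment: (2, 3) occupies columns 0:6, (2, 2) 6:10.
a = x[:, :6].reshape(batch_size, 2, 3).transpose(1, 2).reshape(batch_size, 6)
b = x[:, 6:].reshape(batch_size, 2, 2).transpose(1, 2).reshape(batch_size, 4)
y = torch.cat([a, b], dim=1)

assert y.shape == (0, 10)
torch.testing.assert_close(y, torch.zeros((0, 10)))  # empty vs empty passes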