diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index da872c2a2ce6ea..9221ad79dfacdb 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2557,7 +2557,7 @@ def outer(x, y, name=None): nx = x.reshape((-1, 1)) ny = y.reshape((1, -1)) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.matmul(nx, ny, False, False) else: diff --git a/test/legacy_test/test_outer.py b/test/legacy_test/test_outer.py index 67bb15323338cd..dc76d10a8a0282 100644 --- a/test/legacy_test/test_outer.py +++ b/test/legacy_test/test_outer.py @@ -17,12 +17,14 @@ import numpy as np import paddle -from paddle.static import Program, program_guard +from paddle.pir_utils import test_with_pir_api class TestMultiplyApi(unittest.TestCase): def _run_static_graph_case(self, x_data, y_data): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): paddle.enable_static() x = paddle.static.data( name='x', shape=x_data.shape, dtype=x_data.dtype @@ -53,7 +55,8 @@ def _run_dynamic_graph_case(self, x_data, y_data): res = paddle.outer(x, y) return res.numpy() - def test_multiply(self): + @test_with_pir_api + def test_multiply_static(self): np.random.seed(7) # test static computation graph: 3-d array @@ -86,6 +89,7 @@ def test_multiply(self): res = self._run_static_graph_case(x_data, y_data) np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05) + def test_multiply_dynamic(self): # test dynamic computation graph: 3-d array x_data = np.random.rand(5, 10, 10).astype(np.float64) y_data = np.random.rand(2, 10).astype(np.float64) @@ -138,14 +142,17 @@ def test_multiply(self): class TestMultiplyError(unittest.TestCase): - def test_errors(self): + def test_errors_static(self): # test static computation graph: dtype can not be int8 paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[100], dtype=np.int8) y = paddle.static.data(name='y', shape=[100], dtype=np.int8) self.assertRaises(TypeError, paddle.outer, x, y) + def test_errors_dynamic(self): np.random.seed(7) # test dynamic computation graph: dtype must be Tensor type