From d7b0131e718587579096a8f1e4f7e122829b28df Mon Sep 17 00:00:00 2001 From: Uanu <92366232+uanu2002@users.noreply.github.com> Date: Mon, 22 Jul 2024 12:14:32 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Fix=20PIR=20Unittest=20No.447=20BUAA?= =?UTF-8?q?=E3=80=91Fix=20some=20test=20case=20in=20PIR=20(#66221)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix some test case in PIR * Fix some test case in PIR * fix cmakelist --- test/deprecated/legacy_test/CMakeLists.txt | 1 - test/legacy_test/CMakeLists.txt | 1 + .../legacy_test/test_cross_op.py | 19 ++++++++++--------- 3 files changed, 11 insertions(+), 10 deletions(-) rename test/{deprecated => }/legacy_test/test_cross_op.py (93%) diff --git a/test/deprecated/legacy_test/CMakeLists.txt b/test/deprecated/legacy_test/CMakeLists.txt index fa06867db3bcb6..e2803903c4d493 100644 --- a/test/deprecated/legacy_test/CMakeLists.txt +++ b/test/deprecated/legacy_test/CMakeLists.txt @@ -609,7 +609,6 @@ if(NOT WIN32) endif() # setting timeout value as 15S -set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_lod_tensor_to_selected_rows_deprecated PROPERTIES TIMEOUT 200) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 70d6d1919ddb27..fb7d5a73e7c824 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -778,6 +778,7 @@ if(WITH_DISTRIBUTE) endif() # setting timeout value as 15S +set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120) set_tests_properties(test_isin PROPERTIES TIMEOUT 30) set_tests_properties(test_binomial_op PROPERTIES TIMEOUT 30) set_tests_properties(test_run PROPERTIES TIMEOUT 120) diff --git a/test/deprecated/legacy_test/test_cross_op.py b/test/legacy_test/test_cross_op.py similarity index 93% rename from test/deprecated/legacy_test/test_cross_op.py rename to test/legacy_test/test_cross_op.py index 5f35f5099d5d44..4393b428f5abbd 100644 --- a/test/deprecated/legacy_test/test_cross_op.py +++ b/test/legacy_test/test_cross_op.py @@ -215,18 +215,19 @@ def test_cross_api(self): np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) def test_cross_api1(self): - self.input_data() + with paddle.pir_utils.OldIrGuard(): + self.input_data() - main = paddle.static.Program() - startup = paddle.static.Program() + main = paddle.static.Program() + startup = paddle.static.Program() - # case 1: - with paddle.static.program_guard(main, startup): - x = paddle.static.data(name="x", shape=[-1, 3], dtype="float32") - y = paddle.static.data(name='y', shape=[-1, 3], dtype='float32') + # case 1: + with paddle.static.program_guard(main, startup): + x = paddle.static.data(name="x", shape=[-1, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[-1, 3], dtype='float32') - y_1 = paddle.cross(x, y, name='result') - self.assertEqual(('result' in y_1.name), True) + y_1 = paddle.cross(x, y, name='result') + self.assertEqual(('result' in y_1.name), True) def test_dygraph_api(self): self.input_data()