Skip to content

Commit

Permalink
【Fix PIR Unittest No.447 BUAA】Fix some test case in PIR (PaddlePaddle…
Browse files Browse the repository at this point in the history
…#66221)

* Fix some test case in PIR

* Fix some test case in PIR

* fix cmakelist
  • Loading branch information
uanu2002 authored and lixcli committed Jul 22, 2024
1 parent 94f1791 commit d7b0131
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
1 change: 0 additions & 1 deletion test/deprecated/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,6 @@ if(NOT WIN32)
endif()

# setting timeout value as 15S
set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_lod_tensor_to_selected_rows_deprecated
PROPERTIES TIMEOUT 200)

Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,7 @@ if(WITH_DISTRIBUTE)
endif()

# setting timeout value as 15S
set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_isin PROPERTIES TIMEOUT 30)
set_tests_properties(test_binomial_op PROPERTIES TIMEOUT 30)
set_tests_properties(test_run PROPERTIES TIMEOUT 120)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,19 @@ def test_cross_api(self):
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

def test_cross_api1(self):
self.input_data()
with paddle.pir_utils.OldIrGuard():
self.input_data()

main = paddle.static.Program()
startup = paddle.static.Program()
main = paddle.static.Program()
startup = paddle.static.Program()

# case 1:
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name="x", shape=[-1, 3], dtype="float32")
y = paddle.static.data(name='y', shape=[-1, 3], dtype='float32')
# case 1:
with paddle.static.program_guard(main, startup):
x = paddle.static.data(name="x", shape=[-1, 3], dtype="float32")
y = paddle.static.data(name='y', shape=[-1, 3], dtype='float32')

y_1 = paddle.cross(x, y, name='result')
self.assertEqual(('result' in y_1.name), True)
y_1 = paddle.cross(x, y, name='result')
self.assertEqual(('result' in y_1.name), True)

def test_dygraph_api(self):
self.input_data()
Expand Down

0 comments on commit d7b0131

Please # to comment.