Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Dy2St] lower time > 100 in dy2st unittests #59506

Merged
merged 6 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions test/dygraph_to_static/test_build_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
test_default_and_pir,
test_pt_only,
)
from test_resnet import ResNetHelper

Expand Down Expand Up @@ -66,6 +67,7 @@ def verify_predict(self):
)

@test_ast_only
@test_pt_only
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -78,6 +80,7 @@ def test_resnet(self):
self.verify_predict()

@test_ast_only
@test_pt_only
def test_in_static_mode_mkldnn(self):
paddle.base.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand All @@ -88,7 +91,7 @@ def test_in_static_mode_mkldnn(self):


class TestError(Dy2StTestBase):
@test_legacy_and_pt_and_pir
@test_default_and_pir
def test_type_error(self):
def foo(x):
out = x + 1
Expand Down
11 changes: 6 additions & 5 deletions test/dygraph_to_static/test_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
set_to_static_mode,
test_legacy_only,
)
from seq2seq_dygraph_model import AttentionModel, BaseModel
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter
Expand Down Expand Up @@ -239,13 +239,14 @@ def _test_predict(self, attn_model=False):
msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}",
)

# Disable duplicated test case to avoid timeout
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
@set_to_static_mode(ToStaticMode.SOT)
@test_legacy_only
def test_base_model(self):
self._test_train(attn_model=False)
self._test_predict(attn_model=False)

@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
@set_to_static_mode(ToStaticMode.SOT)
@test_legacy_only
def test_attn_model(self):
self._test_train(attn_model=True)
# TODO(liym27): add predict
Expand Down