Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[regression test] Update run_mantaray_jobs.py for split test order for PyTorch regression test on TPU #572

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
28 changes: 24 additions & 4 deletions dags/mantaray/run_mantaray_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,21 @@
)
xlml_jobs = yaml.safe_load(xlml_jobs_yaml)

# Create a DAG for PyTorch/XLA tests
pattern = r"^(ptxla|pytorchxla).*"
# Create two DAGs for PyTorch/XLA tests
pattern = r"^(ptxla|pytorchxla_part1).*"
pattern2 = r"^(pytorchxla_part2).*"
workload_file_name_list = []
workload_file_name_list_2 = []
for job in xlml_jobs:
if re.match(pattern, job["task_name"]):
workload_file_name_list.append(job["file_name"])
elif re.match(pattern2, job["task_name"]):
workload_file_name_list_2.append(job["file_name"])

# merge all PyTorch/XLA tests into one DAG
with models.DAG(
dag_id="pytorch_xla_model_regression_test_on_trillium",
schedule="0 0 * * *", # everyday at midnight # job["schedule"],
schedule="0 0 * * *", # everyday at midnight
tags=["mantaray", "pytorchxla", "xlml"],
start_date=datetime.datetime(2024, 4, 22),
catchup=False,
Expand All @@ -54,9 +58,25 @@
)
run_workload

# split out sd2 model test
with models.DAG(
dag_id="pytorch_xla_model_regression_test_on_trillium_sd2",
schedule="0 0 * * *", # everyday at midnight # job["schedule"],
tags=["mantaray", "pytorchxla", "xlml"],
start_date=datetime.datetime(2024, 4, 22),
catchup=False,
) as dag:
for workload_file_name in workload_file_name_list_2:
run_workload = mantaray.run_workload.override(
task_id=workload_file_name.split(".")[0]
)(
workload_file_name=workload_file_name,
)
run_workload

# Create a DAG for each job from maxtext
for job in xlml_jobs:
if not re.match(pattern, job["task_name"]):
if (not re.match(pattern, job["task_name"])) and (not re.match(pattern2, job["task_name"])):
with models.DAG(
dag_id=job["task_name"],
schedule=job["schedule"],
Expand Down
Loading