Skip to content

Commit 3f0b85e

Browse files
committed
fix integs
1 parent b139afb commit 3f0b85e

File tree

1 file changed

+16
-32
lines changed

1 file changed

+16
-32
lines changed

Diff for: tests/integ/sagemaker/modules/train/test_model_trainer.py

+16-32
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,22 @@
3535
},
3636
}
3737

38-
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38+
PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
39+
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
40+
source_dir=PARAM_SCRIPT_SOURCE_DIR,
41+
requirements="requirements.txt",
42+
entry_script="train.py",
43+
)
3944

45+
DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
4046

41-
def test_hp_contract_basic_py_script(modules_sagemaker_session):
42-
source_code = SourceCode(
43-
source_dir=f"{DATA_DIR}/modules/params_script",
44-
requirements="requirements.txt",
45-
entry_script="train.py",
46-
)
4747

48+
def test_hp_contract_basic_py_script(modules_sagemaker_session):
4849
model_trainer = ModelTrainer(
4950
sagemaker_session=modules_sagemaker_session,
5051
training_image=DEFAULT_CPU_IMAGE,
5152
hyperparameters=EXPECTED_HYPERPARAMETERS,
52-
source_code=source_code,
53+
source_code=PARAM_SCRIPT_SOURCE_CODE,
5354
base_job_name="hp-contract-basic-py-script",
5455
)
5556

@@ -59,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5960
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
6061
source_code = SourceCode(
6162
source_dir=f"{DATA_DIR}/modules/params_script",
63+
requirements="requirements.txt",
6264
entry_script="train.sh",
6365
)
6466
model_trainer = ModelTrainer(
@@ -73,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7375

7476

7577
def test_hp_contract_mpi_script(modules_sagemaker_session):
76-
source_code = SourceCode(
77-
source_dir=f"{DATA_DIR}/modules/params_script",
78-
entry_script="train.py",
79-
)
8078
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
8179
model_trainer = ModelTrainer(
8280
sagemaker_session=modules_sagemaker_session,
8381
training_image=DEFAULT_CPU_IMAGE,
8482
compute=compute,
8583
hyperparameters=EXPECTED_HYPERPARAMETERS,
86-
source_code=source_code,
84+
source_code=PARAM_SCRIPT_SOURCE_CODE,
8785
distributed=MPI(),
8886
base_job_name="hp-contract-mpi-script",
8987
)
@@ -92,17 +90,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9290

9391

9492
def test_hp_contract_torchrun_script(modules_sagemaker_session):
95-
source_code = SourceCode(
96-
source_dir=f"{DATA_DIR}/modules/params_script",
97-
entry_script="train.py",
98-
)
9993
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
10094
model_trainer = ModelTrainer(
10195
sagemaker_session=modules_sagemaker_session,
10296
training_image=DEFAULT_CPU_IMAGE,
10397
compute=compute,
10498
hyperparameters=EXPECTED_HYPERPARAMETERS,
105-
source_code=source_code,
99+
source_code=PARAM_SCRIPT_SOURCE_CODE,
106100
distributed=Torchrun(),
107101
base_job_name="hp-contract-torchrun-script",
108102
)
@@ -111,33 +105,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111105

112106

113107
def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
114-
source_dir = f"{DATA_DIR}/modules/params_script"
115-
source_code = SourceCode(
116-
source_dir=source_dir,
117-
entry_script="train.py",
118-
)
119108
model_trainer = ModelTrainer(
120109
sagemaker_session=modules_sagemaker_session,
121110
training_image=DEFAULT_CPU_IMAGE,
122-
hyperparameters=f"{source_dir}/hyperparameters.json",
123-
source_code=source_code,
111+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json",
112+
source_code=PARAM_SCRIPT_SOURCE_CODE,
124113
base_job_name="hp-contract-hyperparameter-json",
125114
)
126115
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
127116
model_trainer.train()
128117

129118

130119
def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
131-
source_dir = f"{DATA_DIR}/modules/params_script"
132-
source_code = SourceCode(
133-
source_dir=source_dir,
134-
entry_script="train.py",
135-
)
136120
model_trainer = ModelTrainer(
137121
sagemaker_session=modules_sagemaker_session,
138122
training_image=DEFAULT_CPU_IMAGE,
139-
hyperparameters=f"{source_dir}/hyperparameters.yaml",
140-
source_code=source_code,
123+
hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml",
124+
source_code=PARAM_SCRIPT_SOURCE_CODE,
141125
base_job_name="hp-contract-hyperparameter-yaml",
142126
)
143127
assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS

0 commit comments

Comments
 (0)