35
35
},
36
36
}
37
37
38
- DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38
+ PARAM_SCRIPT_SOURCE_DIR = f"{ DATA_DIR } /modules/params_script"
39
+ PARAM_SCRIPT_SOURCE_CODE = SourceCode (
40
+ source_dir = PARAM_SCRIPT_SOURCE_DIR ,
41
+ requirements = "requirements.txt" ,
42
+ entry_script = "train.py" ,
43
+ )
39
44
45
+ DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
40
46
41
- def test_hp_contract_basic_py_script (modules_sagemaker_session ):
42
- source_code = SourceCode (
43
- source_dir = f"{ DATA_DIR } /modules/params_script" ,
44
- requirements = "requirements.txt" ,
45
- entry_script = "train.py" ,
46
- )
47
47
48
+ def test_hp_contract_basic_py_script (modules_sagemaker_session ):
48
49
model_trainer = ModelTrainer (
49
50
sagemaker_session = modules_sagemaker_session ,
50
51
training_image = DEFAULT_CPU_IMAGE ,
51
52
hyperparameters = EXPECTED_HYPERPARAMETERS ,
52
- source_code = source_code ,
53
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
53
54
base_job_name = "hp-contract-basic-py-script" ,
54
55
)
55
56
@@ -59,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
59
60
def test_hp_contract_basic_sh_script (modules_sagemaker_session ):
60
61
source_code = SourceCode (
61
62
source_dir = f"{ DATA_DIR } /modules/params_script" ,
63
+ requirements = "requirements.txt" ,
62
64
entry_script = "train.sh" ,
63
65
)
64
66
model_trainer = ModelTrainer (
@@ -73,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
73
75
74
76
75
77
def test_hp_contract_mpi_script (modules_sagemaker_session ):
76
- source_code = SourceCode (
77
- source_dir = f"{ DATA_DIR } /modules/params_script" ,
78
- entry_script = "train.py" ,
79
- )
80
78
compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
81
79
model_trainer = ModelTrainer (
82
80
sagemaker_session = modules_sagemaker_session ,
83
81
training_image = DEFAULT_CPU_IMAGE ,
84
82
compute = compute ,
85
83
hyperparameters = EXPECTED_HYPERPARAMETERS ,
86
- source_code = source_code ,
84
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
87
85
distributed = MPI (),
88
86
base_job_name = "hp-contract-mpi-script" ,
89
87
)
@@ -92,17 +90,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
92
90
93
91
94
92
def test_hp_contract_torchrun_script (modules_sagemaker_session ):
95
- source_code = SourceCode (
96
- source_dir = f"{ DATA_DIR } /modules/params_script" ,
97
- entry_script = "train.py" ,
98
- )
99
93
compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
100
94
model_trainer = ModelTrainer (
101
95
sagemaker_session = modules_sagemaker_session ,
102
96
training_image = DEFAULT_CPU_IMAGE ,
103
97
compute = compute ,
104
98
hyperparameters = EXPECTED_HYPERPARAMETERS ,
105
- source_code = source_code ,
99
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
106
100
distributed = Torchrun (),
107
101
base_job_name = "hp-contract-torchrun-script" ,
108
102
)
@@ -111,33 +105,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111
105
112
106
113
107
def test_hp_contract_hyperparameter_json (modules_sagemaker_session ):
114
- source_dir = f"{ DATA_DIR } /modules/params_script"
115
- source_code = SourceCode (
116
- source_dir = source_dir ,
117
- entry_script = "train.py" ,
118
- )
119
108
model_trainer = ModelTrainer (
120
109
sagemaker_session = modules_sagemaker_session ,
121
110
training_image = DEFAULT_CPU_IMAGE ,
122
- hyperparameters = f"{ source_dir } /hyperparameters.json" ,
123
- source_code = source_code ,
111
+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.json" ,
112
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
124
113
base_job_name = "hp-contract-hyperparameter-json" ,
125
114
)
126
115
assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
127
116
model_trainer .train ()
128
117
129
118
130
119
def test_hp_contract_hyperparameter_yaml (modules_sagemaker_session ):
131
- source_dir = f"{ DATA_DIR } /modules/params_script"
132
- source_code = SourceCode (
133
- source_dir = source_dir ,
134
- entry_script = "train.py" ,
135
- )
136
120
model_trainer = ModelTrainer (
137
121
sagemaker_session = modules_sagemaker_session ,
138
122
training_image = DEFAULT_CPU_IMAGE ,
139
- hyperparameters = f"{ source_dir } /hyperparameters.yaml" ,
140
- source_code = source_code ,
123
+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.yaml" ,
124
+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
141
125
base_job_name = "hp-contract-hyperparameter-yaml" ,
142
126
)
143
127
assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
0 commit comments