Skip to content

feat: Allow ModelTrainer to accept hyperparameters file #5059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 5, 2025
Merged
32 changes: 28 additions & 4 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import json
import shutil
from tempfile import TemporaryDirectory

from typing import Optional, List, Union, Dict, Any, ClassVar
import yaml

from graphene.utils.str_converters import to_camel_case, to_snake_case

Expand Down Expand Up @@ -195,8 +195,9 @@ class ModelTrainer(BaseModel):
Defaults to "File".
environment (Optional[Dict[str, str]]):
The environment variables for the training job.
hyperparameters (Optional[Dict[str, Any]]):
The hyperparameters for the training job.
hyperparameters (Optional[Union[Dict[str, Any], str]]):
The hyperparameters for the training job. Can be a dictionary of hyperparameters
or a path to a JSON or YAML hyperparameters file.
tags (Optional[List[Tag]]):
An array of key-value pairs. You can use tags to categorize your AWS resources
in different ways, for example, by purpose, owner, or environment.
Expand Down Expand Up @@ -226,7 +227,7 @@ class ModelTrainer(BaseModel):
checkpoint_config: Optional[CheckpointConfig] = None
training_input_mode: Optional[str] = "File"
environment: Optional[Dict[str, str]] = {}
hyperparameters: Optional[Dict[str, Any]] = {}
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
tags: Optional[List[Tag]] = None
local_container_root: Optional[str] = os.getcwd()

Expand Down Expand Up @@ -470,6 +471,29 @@ def model_post_init(self, __context: Any):
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
)

if self.hyperparameters and isinstance(self.hyperparameters, str):
    # A string value is treated as a path to a JSON or YAML file whose
    # contents become the hyperparameters dict.
    if not os.path.exists(self.hyperparameters):
        raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
    hyperparameters_file = self.hyperparameters
    logger.info("Loading hyperparameters from file: %s", hyperparameters_file)
    with open(hyperparameters_file, "r") as f:
        contents = f.read()
    # Try JSON first (every JSON document is also valid YAML, but JSON
    # parsing is stricter and cheaper), then fall back to YAML.
    try:
        self.hyperparameters = json.loads(contents)
        logger.debug("Hyperparameters loaded as JSON")
    except json.JSONDecodeError:
        try:
            # Parse into a local first so self.hyperparameters (and the
            # error message below) still holds the file path if the YAML
            # document is not a mapping.
            parsed = yaml.safe_load(contents)
            if not isinstance(parsed, dict):
                raise ValueError("YAML contents must be a valid mapping")
            self.hyperparameters = parsed
            logger.debug("Hyperparameters loaded as YAML")
        except (yaml.YAMLError, ValueError) as e:
            # Chain the original parse error for easier debugging.
            raise ValueError(
                f"Invalid hyperparameters file: {hyperparameters_file}. "
                "Must be a valid JSON or YAML file."
            ) from e

if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
session = self.sagemaker_session
base_job_name = self.base_job_name
Expand Down
15 changes: 15 additions & 0 deletions tests/data/modules/params_script/hyperparameters.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"integer": 1,
"boolean": true,
"float": 3.14,
"string": "Hello World",
"list": [1, 2, 3],
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": true
}
}
19 changes: 19 additions & 0 deletions tests/data/modules/params_script/hyperparameters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
integer: 1
boolean: true
float: 3.14
string: "Hello World"
list:
- 1
- 2
- 3
dict:
string: value
integer: 3
float: 3.14
list:
- 1
- 2
- 3
dict:
key: value
boolean: true
1 change: 1 addition & 0 deletions tests/data/modules/params_script/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
omegaconf
97 changes: 94 additions & 3 deletions tests/data/modules/params_script/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import argparse
import json
import os
from typing import List, Dict, Any
from dataclasses import dataclass
from omegaconf import OmegaConf

EXPECTED_HYPERPARAMETERS = {
"integer": 1,
Expand All @@ -26,6 +29,7 @@
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": True,
Expand Down Expand Up @@ -117,7 +121,7 @@ def main():
assert isinstance(params["dict"], dict)

params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
print(params)
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
Expand All @@ -132,9 +136,96 @@ def main():
assert isinstance(params["float"], float)
assert isinstance(params["list"], list)
assert isinstance(params["dict"], dict)
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")

print("Test passed.")
# Local JSON - DictConfig OmegaConf
params = OmegaConf.load("hyperparameters.json")

print(f"Local hyperparameters.json: {params}")
assert params.string == EXPECTED_HYPERPARAMETERS["string"]
assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert params.float == EXPECTED_HYPERPARAMETERS["float"]
assert params.list == EXPECTED_HYPERPARAMETERS["list"]
assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

@dataclass
class DictConfig:
string: str
integer: int
boolean: bool
float: float
list: List[int]
dict: Dict[str, Any]

@dataclass
class HPConfig:
string: str
integer: int
boolean: bool
float: float
list: List[int]
dict: DictConfig

# Local JSON - Structured OmegaConf
hp_config: HPConfig = OmegaConf.merge(
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
)
print(f"Local hyperparameters.json - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]

# Local YAML - Structured OmegaConf
hp_config: HPConfig = OmegaConf.merge(
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
)
print(f"Local hyperparameters.yaml - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")

# HP Dict - Structured OmegaConf
hp_dict = json.loads(os.environ["SM_HPS"])
hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
print(f"SM_HPS - Structured: {hp_config}")
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
print(f"SM_HPS -> hyperparameters: {hp_config}")


if __name__ == "__main__":
Expand Down
52 changes: 36 additions & 16 deletions tests/integ/sagemaker/modules/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,29 @@
"dict": {
"string": "value",
"integer": 3,
"float": 3.14,
"list": [1, 2, 3],
"dict": {"key": "value"},
"boolean": True,
},
}

PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script"
PARAM_SCRIPT_SOURCE_CODE = SourceCode(
source_dir=PARAM_SCRIPT_SOURCE_DIR,
requirements="requirements.txt",
entry_script="train.py",
)

DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"


def test_hp_contract_basic_py_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
entry_script="train.py",
)

model_trainer = ModelTrainer(
sagemaker_session=modules_sagemaker_session,
training_image=DEFAULT_CPU_IMAGE,
hyperparameters=EXPECTED_HYPERPARAMETERS,
source_code=source_code,
source_code=PARAM_SCRIPT_SOURCE_CODE,
base_job_name="hp-contract-basic-py-script",
)

Expand All @@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
def test_hp_contract_basic_sh_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
requirements="requirements.txt",
entry_script="train.sh",
)
model_trainer = ModelTrainer(
Expand All @@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):


def test_hp_contract_mpi_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
entry_script="train.py",
)
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
model_trainer = ModelTrainer(
sagemaker_session=modules_sagemaker_session,
training_image=DEFAULT_CPU_IMAGE,
compute=compute,
hyperparameters=EXPECTED_HYPERPARAMETERS,
source_code=source_code,
source_code=PARAM_SCRIPT_SOURCE_CODE,
distributed=MPI(),
base_job_name="hp-contract-mpi-script",
)
Expand All @@ -90,19 +90,39 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):


def test_hp_contract_torchrun_script(modules_sagemaker_session):
source_code = SourceCode(
source_dir=f"{DATA_DIR}/modules/params_script",
entry_script="train.py",
)
compute = Compute(instance_type="ml.m5.xlarge", instance_count=2)
model_trainer = ModelTrainer(
sagemaker_session=modules_sagemaker_session,
training_image=DEFAULT_CPU_IMAGE,
compute=compute,
hyperparameters=EXPECTED_HYPERPARAMETERS,
source_code=source_code,
source_code=PARAM_SCRIPT_SOURCE_CODE,
distributed=Torchrun(),
base_job_name="hp-contract-torchrun-script",
)

model_trainer.train()


def test_hp_contract_hyperparameter_json(modules_sagemaker_session):
    """Hyperparameters passed as a JSON file path are loaded into the expected dict."""
    hp_file = f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json"
    trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=hp_file,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-json",
    )
    # The file contents should have been parsed into a plain dict at init time.
    assert trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    trainer.train()


def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session):
    """Hyperparameters passed as a YAML file path are loaded into the expected dict."""
    hp_file = f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml"
    trainer = ModelTrainer(
        sagemaker_session=modules_sagemaker_session,
        training_image=DEFAULT_CPU_IMAGE,
        hyperparameters=hp_file,
        source_code=PARAM_SCRIPT_SOURCE_CODE,
        base_job_name="hp-contract-hyperparameter-yaml",
    )
    # The file contents should have been parsed into a plain dict at init time.
    assert trainer.hyperparameters == EXPECTED_HYPERPARAMETERS
    trainer.train()
Loading