diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index a47d8f91ad..bb7c4168e6 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -18,8 +18,8 @@ import json import shutil from tempfile import TemporaryDirectory - from typing import Optional, List, Union, Dict, Any, ClassVar +import yaml from graphene.utils.str_converters import to_camel_case, to_snake_case @@ -195,8 +195,9 @@ class ModelTrainer(BaseModel): Defaults to "File". environment (Optional[Dict[str, str]]): The environment variables for the training job. - hyperparameters (Optional[Dict[str, Any]]): - The hyperparameters for the training job. + hyperparameters (Optional[Union[Dict[str, Any], str]]): + The hyperparameters for the training job. Can be a dictionary of hyperparameters + or a path to a JSON/YAML hyperparameters file. tags (Optional[List[Tag]]): An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment. @@ -226,7 +227,7 @@ class ModelTrainer(BaseModel): checkpoint_config: Optional[CheckpointConfig] = None training_input_mode: Optional[str] = "File" environment: Optional[Dict[str, str]] = {} - hyperparameters: Optional[Dict[str, Any]] = {} + hyperparameters: Optional[Union[Dict[str, Any], str]] = {} tags: Optional[List[Tag]] = None local_container_root: Optional[str] = os.getcwd() @@ -470,6 +471,29 @@ def model_post_init(self, __context: Any): f"StoppingCondition not provided. 
Using default:\n{self.stopping_condition}" ) + if self.hyperparameters and isinstance(self.hyperparameters, str): + if not os.path.exists(self.hyperparameters): + raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}") + logger.info(f"Loading hyperparameters from file: {self.hyperparameters}") + with open(self.hyperparameters, "r") as f: + contents = f.read() + try: + self.hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + self.hyperparameters = yaml.safe_load(contents) + if not isinstance(self.hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + logger.info(f"hyperparameters: {self.hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {self.hyperparameters}. " + "Must be a valid JSON or YAML file." + ) + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: session = self.sagemaker_session base_job_name = self.base_job_name diff --git a/tests/data/modules/params_script/hyperparameters.json b/tests/data/modules/params_script/hyperparameters.json new file mode 100644 index 0000000000..f637288dbe --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.json @@ -0,0 +1,15 @@ +{ + "integer": 1, + "boolean": true, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": true + } +} \ No newline at end of file diff --git a/tests/data/modules/params_script/hyperparameters.yaml b/tests/data/modules/params_script/hyperparameters.yaml new file mode 100644 index 0000000000..9e3011daf2 --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.yaml @@ -0,0 +1,19 @@ +integer: 1 +boolean: true +float: 3.14 +string: 
"Hello World" +list: + - 1 + - 2 + - 3 +dict: + string: value + integer: 3 + float: 3.14 + list: + - 1 + - 2 + - 3 + dict: + key: value + boolean: true \ No newline at end of file diff --git a/tests/data/modules/params_script/requirements.txt b/tests/data/modules/params_script/requirements.txt new file mode 100644 index 0000000000..3d2e72e354 --- /dev/null +++ b/tests/data/modules/params_script/requirements.txt @@ -0,0 +1 @@ +omegaconf diff --git a/tests/data/modules/params_script/train.py b/tests/data/modules/params_script/train.py index 8d3924a325..9b8cb2c82f 100644 --- a/tests/data/modules/params_script/train.py +++ b/tests/data/modules/params_script/train.py @@ -16,6 +16,9 @@ import argparse import json import os +from typing import List, Dict, Any +from dataclasses import dataclass +from omegaconf import OmegaConf EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -26,6 +29,7 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, @@ -117,7 +121,7 @@ def main(): assert isinstance(params["dict"], dict) params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"] - print(params) + print(f"SM_TRAINING_ENV -> hyperparameters: {params}") assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] @@ -132,9 +136,96 @@ def main(): assert isinstance(params["float"], float) assert isinstance(params["list"], list) assert isinstance(params["dict"], dict) - print(f"SM_TRAINING_ENV -> hyperparameters: {params}") - print("Test passed.") + # Local JSON - DictConfig OmegaConf + params = OmegaConf.load("hyperparameters.json") + + print(f"Local hyperparameters.json: {params}") + assert params.string == EXPECTED_HYPERPARAMETERS["string"] + assert params.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert 
params.float == EXPECTED_HYPERPARAMETERS["float"] + assert params.list == EXPECTED_HYPERPARAMETERS["list"] + assert params.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + @dataclass + class DictConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: Dict[str, Any] + + @dataclass + class HPConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: DictConfig + + # Local JSON - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json") + ) + print(f"Local hyperparameters.json - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + # Local YAML - Structured OmegaConf + hp_config: HPConfig = 
OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml") + ) + print(f"Local hyperparameters.yaml - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"hyperparameters.yaml -> hyperparameters: {hp_config}") + + # HP Dict - Structured OmegaConf + hp_dict = json.loads(os.environ["SM_HPS"]) + hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict)) + print(f"SM_HPS - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert 
hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"SM_HPS -> hyperparameters: {hp_config}") if __name__ == "__main__": diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cd298402b2..a19f6d0e8b 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -28,26 +28,29 @@ "dict": { "string": "value", "integer": 3, + "float": 3.14, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": True, }, } +PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script" +PARAM_SCRIPT_SOURCE_CODE = SourceCode( + source_dir=PARAM_SCRIPT_SOURCE_DIR, + requirements="requirements.txt", + entry_script="train.py", +) + DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" def test_hp_contract_basic_py_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) - model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, base_job_name="hp-contract-basic-py-script", ) @@ -57,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session): def test_hp_contract_basic_sh_script(modules_sagemaker_session): source_code = SourceCode( source_dir=f"{DATA_DIR}/modules/params_script", + requirements="requirements.txt", entry_script="train.sh", ) model_trainer = ModelTrainer( @@ -71,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session): def test_hp_contract_mpi_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = 
Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=MPI(), base_job_name="hp-contract-mpi-script", ) @@ -90,19 +90,39 @@ def test_hp_contract_mpi_script(modules_sagemaker_session): def test_hp_contract_torchrun_script(modules_sagemaker_session): - source_code = SourceCode( - source_dir=f"{DATA_DIR}/modules/params_script", - entry_script="train.py", - ) compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) model_trainer = ModelTrainer( sagemaker_session=modules_sagemaker_session, training_image=DEFAULT_CPU_IMAGE, compute=compute, hyperparameters=EXPECTED_HYPERPARAMETERS, - source_code=source_code, + source_code=PARAM_SCRIPT_SOURCE_CODE, distributed=Torchrun(), base_job_name="hp-contract-torchrun-script", ) model_trainer.train() + + +def test_hp_contract_hyperparameter_json(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-json", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-yaml", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py 
b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 29da03bcd9..194bb44988 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -17,9 +17,10 @@ import tempfile import json import os +import yaml import pytest from pydantic import ValidationError -from unittest.mock import patch, MagicMock, ANY +from unittest.mock import patch, MagicMock, ANY, mock_open from sagemaker import image_uris from sagemaker_core.main.resources import TrainingJob @@ -1093,3 +1094,93 @@ def test_destructor_cleanup(mock_tmp_dir, modules_session): mock_tmp_dir.assert_not_called() del model_trainer mock_tmp_dir.cleanup.assert_called_once() + + +@patch("os.path.exists") +def test_hyperparameters_valid_json(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.json", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.json", "r") + mock_exists.assert_called_once_with("hyperparameters.json") + + +@patch("os.path.exists") +def test_hyperparameters_valid_yaml(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + assert model_trainer.hyperparameters == 
expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.yaml", "r") + mock_exists.assert_called_once_with("hyperparameters.yaml") + + +def test_hyperparameters_not_exist(modules_session): + with pytest.raises(ValueError): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="nonexistent.json", + ) + + +@patch("os.path.exists") +def test_hyperparameters_invalid(mock_exists, modules_session): + mock_exists.return_value = True + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="- item1\n- item2") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # Must be valid YAML + mock_file_open = mock_open(read_data="* invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + )