Skip to content

Commit dc5ac8f

Browse files
committed
Allow ModelTrainer to accept hyperparameter file and create Hyperparameter class
1 parent b116e2f commit dc5ac8f

File tree

2 files changed

+114
-3
lines changed

2 files changed

+114
-3
lines changed

Diff for: src/sagemaker/modules/hyperparameters.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import os
2+
import json
3+
import dataclasses
4+
from typing import Any, Type, TypeVar
5+
6+
from sagemaker.modules import logger
7+
8+
T = TypeVar("T")
9+
10+
11+
class DictConfig:
    """Configuration holder supporting both dict-style and dot-notation access.

    Nested dictionaries are recursively wrapped in ``DictConfig`` so that
    ``config["nested"]["key"]``, ``config["nested"].key`` and
    ``config.nested.key`` all work and return the same wrapped object.
    """

    def __init__(self, **kwargs):
        """Initialize from keyword arguments.

        Args:
            **kwargs: Arbitrary configuration key/value pairs. Values that are
                plain dicts are recursively converted to ``DictConfig``.
        """
        # Backing store for dict-style access. Populated through __setitem__
        # so that nested-dict wrapping happens in exactly one place and both
        # access styles stay consistent.
        self._data = {}
        for key, value in kwargs.items():
            self[key] = value

    def __setitem__(self, key: str, value: Any):
        """Enable dictionary-style assignment: config['key'] = value.

        Nested dicts are wrapped in ``DictConfig`` before being stored, so the
        value seen via ``config[key]`` and ``config.key`` is the same object.
        """
        if isinstance(value, dict):
            value = DictConfig(**value)
        self._data[key] = value
        setattr(self, key, value)

    def __getitem__(self, key: str) -> Any:
        """Enable dictionary-style access: config['key'].

        Raises:
            KeyError: If ``key`` was never set.
        """
        return self._data[key]

    def __str__(self) -> str:
        """String representation of the underlying data."""
        return str(self._data)

    def __repr__(self) -> str:
        """Detailed string representation."""
        return f"DictConfig({self._data})"
41+
42+
43+
class Hyperparameters:
    """Class to load hyperparameters in a training container.

    Hyperparameters are read from the ``SM_HPS`` environment variable, which
    SageMaker populates with a JSON object of the hyperparameters passed to
    the training job.
    """

    @staticmethod
    def _read_env_hps() -> dict:
        """Read and parse hyperparameters from ``SM_HPS``.

        Returns:
            dict: Parsed hyperparameters; empty dict (with a warning logged)
            when ``SM_HPS`` is unset or empty.
        """
        hps = json.loads(os.environ.get("SM_HPS", "{}"))
        if not hps:
            logger.warning("No hyperparameters found in SM_HPS environment variable.")
        return hps

    @staticmethod
    def load() -> "DictConfig":
        """Loads hyperparameters in training container.

        Example:

        .. code:: python

            from sagemaker.modules.hyperparameters import Hyperparameters

            hps = Hyperparameters.load()
            print(hps.batch_size)

        Returns:
            DictConfig: hyperparameters as a DictConfig object
        """
        return DictConfig(**Hyperparameters._read_env_hps())

    @staticmethod
    def load_structured(dataclass_type: "Type[T]") -> "T":
        """Loads hyperparameters as a structured dataclass.

        Example:

        .. code:: python

            from sagemaker.modules.hyperparameters import Hyperparameters

            @dataclass
            class TrainingConfig:
                batch_size: int
                learning_rate: float

            config = Hyperparameters.load_structured(TrainingConfig)
            print(config.batch_size)  # typed int

        Args:
            dataclass_type: Dataclass type to structure the config.

        Returns:
            dataclass_type: Instance of the provided dataclass type.

        Raises:
            ValueError: If ``dataclass_type`` is not a dataclass type.
            TypeError: If ``SM_HPS`` contains keys that are not fields of
                ``dataclass_type`` (raised by the dataclass constructor).
        """
        if not dataclasses.is_dataclass(dataclass_type):
            raise ValueError(f"{dataclass_type} is not a dataclass type.")

        # Convert hyperparameters to dataclass
        return dataclass_type(**Hyperparameters._read_env_hps())

Diff for: src/sagemaker/modules/train/model_trainer.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import json
1919
import shutil
20+
import yaml
2021
from tempfile import TemporaryDirectory
2122

2223
from typing import Optional, List, Union, Dict, Any, ClassVar
@@ -195,8 +196,9 @@ class ModelTrainer(BaseModel):
195196
Defaults to "File".
196197
environment (Optional[Dict[str, str]]):
197198
The environment variables for the training job.
198-
hyperparameters (Optional[Dict[str, Any]]):
199-
The hyperparameters for the training job.
199+
hyperparameters (Optional[Union[Dict[str, Any], str]):
200+
The hyperparameters for the training job. Can be a dictionary of hyperparameters
201+
or a path to hyperparameters json/yaml file.
200202
tags (Optional[List[Tag]]):
201203
An array of key-value pairs. You can use tags to categorize your AWS resources
202204
in different ways, for example, by purpose, owner, or environment.
@@ -226,7 +228,7 @@ class ModelTrainer(BaseModel):
226228
checkpoint_config: Optional[CheckpointConfig] = None
227229
training_input_mode: Optional[str] = "File"
228230
environment: Optional[Dict[str, str]] = {}
229-
hyperparameters: Optional[Dict[str, Any]] = {}
231+
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
230232
tags: Optional[List[Tag]] = None
231233
local_container_root: Optional[str] = os.getcwd()
232234

@@ -470,6 +472,17 @@ def model_post_init(self, __context: Any):
470472
f"StoppingCondition not provided. Using default:\n{self.stopping_condition}"
471473
)
472474

475+
if self.hyperparameters and isinstance(self.hyperparameters, str):
476+
if not os.path.exists(self.hyperparameters):
477+
raise ValueError(f"Hyperparameters file not found: {self.hyperparameters}")
478+
logger.info(f"Loading hyperparameters from file: {self.hyperparameters}")
479+
if self.hyperparameters.endswith(".json"):
480+
with open(self.hyperparameters, "r") as f:
481+
self.hyperparameters = json.load(f)
482+
elif self.hyperparameters.endswith(".yaml"):
483+
with open(self.hyperparameters, "r") as f:
484+
self.hyperparameters = yaml.safe_load(f)
485+
473486
if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None:
474487
session = self.sagemaker_session
475488
base_job_name = self.base_job_name

0 commit comments

Comments
 (0)