Skip to content

Commit 3f1d2de

Browse files
Narrohagsagemaker-botbeniericvarunmorisgwang111
authored andcommitted
feature: support training for JumpStart model references as part of Curated Hub Phase 2 (aws#5070)
* change: update image_uri_configs 01-27-2025 06:18:13 PST * fix: skip TF tests for unsupported versions (aws#5007) * fix: skip TF tests for unsupported versions * flake8 * change: update image_uri_configs 01-29-2025 06:18:08 PST * chore: add new images for HF TGI (aws#5005) * feat: add pytorch-tgi-inference 2.4.0 * add tgi 3.0.1 image * skip faulty test * formatting * formatting * add hf pytorch training 4.46 * update version alias * add py311 to training version * update tests with pyversion 311 * formatting --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * feat: use jumpstart deployment config image as default optimization image (aws#4992) Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * prepare release v2.238.0 * update development version to v2.238.1.dev0 * Fix ssh host policy (aws#4966) * Fix ssh host policy * Filter policy by algo- * Add docstring * Fix pylint * Fix docstyle summary * Unit test * Fix unit test * Change to unit test * Fix unit tests * Test comment out flaky tests * Readd the flaky tests * Remove flaky asserts * Remove flaky asserts --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * change: Allow telemetry only in supported regions (aws#5009) * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions * change: Allow telemetry only in supported regions --------- Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com> * mpirun protocol - distributed training with @Remote decorator (aws#4998) * implemented multi-node distribution with @Remote function * completed unit tests * added distributed training with CPU and torchrun * backwards compatibility nproc_per_node * fixing code: permissions for non-root users, integration tests * fixed docstyle * refactor nproc_per_node for backwards compatibility * refactor nproc_per_node for backwards compatibility * pylint fix, newlines * added unit tests for bootstrap_environment remote * added mpirun protocol for distributed training with @Remote decorator * aligned mpi_utils_remote.py to mpi_utils.py for estimator * updated docstring for sagemaker sdk doc --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * feat: Add support for deepseek recipes (aws#5011) * feat: Add support for deeepseek recipes * pylint * add unit test * feat: [JumpStart] Add access configs and training instance type variants artifact uri handling for Curated Hub Phase 2 training integration (aws#1653) * Add access config to training input for Curated Hub Training Integration * Add support to retrieve instance specific training artifact keys * Fix some typos and naming issues * Fix more typos * fix formatting issues with black * modify access config logic so accept_eula is passed into fit * update black formatting * Add more unit tests for passing access configs * fix style errors * fix for failing integ test * fix styles and integ test error * skip blocking integ test * fix formatting * remove env vars when access configs are being used * fix docstyle issue * update usage of access configs, remove conversion of training artifact key to uri * fix styling issues * fix styling issues * fix unit tests * fix adding hubaccessconfig only if hubcontentarn exists * move logic to JumpStartEstimator from Job * Fix styling issues * Remove unused code * fix styling issues * fix unit test failure * fix some formatting, add comments * remove typing for estimator in get_access_configs function * fix circular import dependency * fix styling issues --------- Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> * Always add code channel, regardless of network isolation (aws#1657) * fix formatting issue * fix formatting issue * fix formatting issue * fix tensorflow file --------- Co-authored-by: sagemaker-bot <sagemaker-bot@amazon.com> Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Co-authored-by: varunmoris <176621270+varunmoris@users.noreply.github.com> Co-authored-by: Gary Wang <38331932+gwang111@users.noreply.github.com> Co-authored-by: ci <ci> Co-authored-by: parknate@ <parknate@amazon.com> Co-authored-by: rsareddy0329 <rsareddy0329@gmail.com> Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com> Co-authored-by: Bruno Pistone <brn.pistone@gmail.com>
1 parent ed2c7e7 commit 3f1d2de

File tree

18 files changed

+502
-80
lines changed

18 files changed

+502
-80
lines changed

src/sagemaker/estimator.py

-1
Original file line numberDiff line numberDiff line change
@@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25502550
raise ValueError(
25512551
"File URIs are supported in local mode only. Please use a S3 URI instead."
25522552
)
2553-
25542553
config = _Job._load_config(inputs, estimator)
25552554

25562555
current_hyperparameters = estimator.hyperparameters()

src/sagemaker/inputs.py

+30
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __init__(
4343
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
4444
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
4545
shuffle_config: Optional["ShuffleConfig"] = None,
46+
hub_access_config: Optional[dict] = None,
47+
model_access_config: Optional[dict] = None,
4648
):
4749
r"""Create a definition for input data used by an SageMaker training job.
4850
@@ -102,6 +104,13 @@ def __init__(
102104
shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables
103105
shuffling on this channel. See the SageMaker API documentation for more info:
104106
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
107+
hub_access_config (dict): Specify the HubAccessConfig of a
108+
Model Reference for which a training job is being created for.
109+
model_access_config (dict): For models that require a Model Access Config, specify True
110+
or False for to indicate whether model terms of use have been accepted.
111+
The `accept_eula` value must be explicitly defined as `True` in order to
112+
accept the end-user license agreement (EULA) that some
113+
models require. (Default: None).
105114
"""
106115
self.config = {
107116
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
@@ -129,6 +138,27 @@ def __init__(
129138
self.config["TargetAttributeName"] = target_attribute_name
130139
if shuffle_config is not None:
131140
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}
141+
self.add_hub_access_config(hub_access_config)
142+
self.add_model_access_config(model_access_config)
143+
144+
def add_hub_access_config(self, hub_access_config=None):
145+
"""Add Hub Access Config to the channel's configuration.
146+
147+
Args:
148+
hub_access_config (dict): The HubAccessConfig to be added to the
149+
channel's configuration.
150+
"""
151+
if hub_access_config is not None:
152+
self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config
153+
154+
def add_model_access_config(self, model_access_config=None):
155+
"""Add Model Access Config to the channel's configuration.
156+
157+
Args:
158+
model_access_config (dict): Whether model terms of use have been accepted.
159+
"""
160+
if model_access_config is not None:
161+
self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config
132162

133163

134164
class ShuffleConfig(object):

src/sagemaker/job.py

+46-9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def stop(self):
6565
@staticmethod
6666
def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
6767
"""Placeholder docstring"""
68+
model_access_config, hub_access_config = _Job._get_access_configs(estimator)
6869
input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
6970
role = (
7071
estimator.sagemaker_session.expand_role(estimator.role)
@@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
9596
validate_uri,
9697
content_type="application/x-sagemaker-model",
9798
input_mode="File",
99+
model_access_config=model_access_config,
100+
hub_access_config=hub_access_config,
98101
)
99102
if model_channel:
100103
input_config = [] if input_config is None else input_config
101104
input_config.append(model_channel)
102105

103-
if estimator.enable_network_isolation():
104-
code_channel = _Job._prepare_channel(
105-
input_config, estimator.code_uri, estimator.code_channel_name, validate_uri
106-
)
106+
code_channel = _Job._prepare_channel(
107+
input_config,
108+
estimator.code_uri,
109+
estimator.code_channel_name,
110+
validate_uri,
111+
)
107112

108-
if code_channel:
109-
input_config = [] if input_config is None else input_config
110-
input_config.append(code_channel)
113+
if code_channel:
114+
input_config = [] if input_config is None else input_config
115+
input_config.append(code_channel)
111116

112117
return {
113118
"input_config": input_config,
@@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
118123
"vpc_config": vpc_config,
119124
}
120125

126+
@staticmethod
127+
def _get_access_configs(estimator):
128+
"""Return access configs from estimator object.
129+
130+
JumpStartEstimator uses access configs which need to be added to the model channel,
131+
so they are passed down to the job level.
132+
133+
Args:
134+
estimator (EstimatorBase): estimator object with access config field if applicable
135+
"""
136+
model_access_config, hub_access_config = None, None
137+
if hasattr(estimator, "model_access_config"):
138+
model_access_config = estimator.model_access_config
139+
if hasattr(estimator, "hub_access_config"):
140+
hub_access_config = estimator.hub_access_config
141+
return model_access_config, hub_access_config
142+
121143
@staticmethod
122144
def _format_inputs_to_input_config(inputs, validate_uri=True):
123145
"""Placeholder docstring"""
@@ -173,6 +195,8 @@ def _format_string_uri_input(
173195
input_mode=None,
174196
compression=None,
175197
target_attribute_name=None,
198+
model_access_config=None,
199+
hub_access_config=None,
176200
):
177201
"""Placeholder docstring"""
178202
s3_input_result = TrainingInput(
@@ -181,6 +205,8 @@ def _format_string_uri_input(
181205
input_mode=input_mode,
182206
compression=compression,
183207
target_attribute_name=target_attribute_name,
208+
model_access_config=model_access_config,
209+
hub_access_config=hub_access_config,
184210
)
185211
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
186212
return s3_input_result
@@ -193,7 +219,11 @@ def _format_string_uri_input(
193219
)
194220
if isinstance(uri_input, str):
195221
return s3_input_result
196-
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
222+
if isinstance(uri_input, (file_input, FileSystemInput)):
223+
return uri_input
224+
if isinstance(uri_input, TrainingInput):
225+
uri_input.add_hub_access_config(hub_access_config=hub_access_config)
226+
uri_input.add_model_access_config(model_access_config=model_access_config)
197227
return uri_input
198228
if is_pipeline_variable(uri_input):
199229
return s3_input_result
@@ -211,6 +241,8 @@ def _prepare_channel(
211241
validate_uri=True,
212242
content_type=None,
213243
input_mode=None,
244+
model_access_config=None,
245+
hub_access_config=None,
214246
):
215247
"""Placeholder docstring"""
216248
if not channel_uri:
@@ -226,7 +258,12 @@ def _prepare_channel(
226258
raise ValueError("Duplicate channel {} not allowed.".format(channel_name))
227259

228260
channel_input = _Job._format_string_uri_input(
229-
channel_uri, validate_uri, content_type, input_mode
261+
channel_uri,
262+
validate_uri,
263+
content_type,
264+
input_mode,
265+
model_access_config=model_access_config,
266+
hub_access_config=hub_access_config,
230267
)
231268
channel = _Job._convert_input_to_channel(channel_name, channel_input)
232269

src/sagemaker/jumpstart/artifacts/model_uris.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_region_fallback,
3030
verify_model_region_and_return_specs,
3131
)
32+
from sagemaker.s3_utils import is_s3_url
3233
from sagemaker.session import Session
3334
from sagemaker.jumpstart.types import JumpStartModelSpecs
3435

@@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty
7475
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
7576
"""Returns instance specific training artifact key or default one as fallback."""
7677
instance_specific_training_artifact_key: Optional[str] = (
77-
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
78+
model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key(
7879
instance_type=instance_type
7980
)
8081
if instance_type
@@ -185,8 +186,8 @@ def _retrieve_model_uri(
185186
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
186187
or default_jumpstart_bucket
187188
)
188-
189-
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
189+
if not is_s3_url(model_artifact_key):
190+
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
190191

191192
return model_s3_uri
192193

src/sagemaker/jumpstart/estimator.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
validate_model_id_and_get_type,
4242
resolve_model_sagemaker_config_field,
4343
verify_model_region_and_return_specs,
44+
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
45+
get_model_access_config,
46+
get_hub_access_config,
4447
)
4548
from sagemaker.utils import stringify_object, format_tags, Tags
4649
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -619,6 +622,10 @@ def _validate_model_id_and_get_type_hook():
619622
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
620623
self.config_name = estimator_init_kwargs.config_name
621624
self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False)
625+
# Access configs initialized to None, would be given a value when .fit() is called
626+
# if applicable
627+
self.model_access_config = None
628+
self.hub_access_config = None
622629

623630
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
624631

@@ -629,6 +636,7 @@ def fit(
629636
logs: Optional[str] = None,
630637
job_name: Optional[str] = None,
631638
experiment_config: Optional[Dict[str, str]] = None,
639+
accept_eula: Optional[bool] = None,
632640
) -> None:
633641
"""Start training job by calling base ``Estimator`` class ``fit`` method.
634642
@@ -679,8 +687,16 @@ def fit(
679687
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
680688
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
681689
(Default: None).
690+
accept_eula (bool): For models that require a Model Access Config, specify True or
691+
False to indicate whether model terms of use have been accepted.
692+
The `accept_eula` value must be explicitly defined as `True` in order to
693+
accept the end-user license agreement (EULA) that some
694+
models require. (Default: None).
682695
"""
683-
696+
self.model_access_config = get_model_access_config(accept_eula)
697+
self.hub_access_config = get_hub_access_config(
698+
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
699+
)
684700
estimator_fit_kwargs = get_fit_kwargs(
685701
model_id=self.model_id,
686702
model_version=self.model_version,
@@ -695,7 +711,9 @@ def fit(
695711
tolerate_deprecated_model=self.tolerate_deprecated_model,
696712
sagemaker_session=self.sagemaker_session,
697713
config_name=self.config_name,
714+
hub_access_config=self.hub_access_config,
698715
)
716+
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)
699717

700718
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
701719

src/sagemaker/jumpstart/factory/estimator.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
from sagemaker.jumpstart.utils import (
7272
add_hub_content_arn_tags,
7373
add_jumpstart_model_info_tags,
74-
get_eula_message,
7574
get_default_jumpstart_session_with_user_agent_suffix,
7675
get_top_ranked_config_name,
7776
update_dict_if_key_not_present,
@@ -265,6 +264,7 @@ def get_fit_kwargs(
265264
tolerate_deprecated_model: Optional[bool] = None,
266265
sagemaker_session: Optional[Session] = None,
267266
config_name: Optional[str] = None,
267+
hub_access_config: Optional[Dict] = None,
268268
) -> JumpStartEstimatorFitKwargs:
269269
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""
270270

@@ -301,10 +301,32 @@ def get_fit_kwargs(
301301
estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs)
302302
estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs)
303303
estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs)
304+
estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs(
305+
estimator_fit_kwargs, hub_access_config
306+
)
304307

305308
return estimator_fit_kwargs
306309

307310

311+
def _add_hub_access_config_to_kwargs_inputs(
312+
kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
313+
):
314+
"""Adds HubAccessConfig to kwargs inputs"""
315+
316+
if isinstance(kwargs.inputs, str):
317+
kwargs.inputs = TrainingInput(s3_data=kwargs.inputs, hub_access_config=hub_access_config)
318+
elif isinstance(kwargs.inputs, TrainingInput):
319+
kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
320+
elif isinstance(kwargs.inputs, dict):
321+
for k, v in kwargs.inputs.items():
322+
if isinstance(v, str):
323+
kwargs.inputs[k] = TrainingInput(s3_data=v, hub_access_config=hub_access_config)
324+
elif isinstance(kwargs.inputs, TrainingInput):
325+
kwargs.inputs[k].add_hub_access_config(hub_access_config=hub_access_config)
326+
327+
return kwargs
328+
329+
308330
def get_deploy_kwargs(
309331
model_id: str,
310332
model_version: Optional[str] = None,
@@ -668,18 +690,6 @@ def _add_env_to_kwargs(
668690
value,
669691
)
670692

671-
environment = getattr(kwargs, "environment", {}) or {}
672-
if (
673-
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
674-
and str(environment.get("accept_eula", "")).lower() != "true"
675-
):
676-
model_specs = kwargs.specs
677-
if model_specs.is_gated_model():
678-
raise ValueError(
679-
"Need to define ‘accept_eula'='true' within Environment. "
680-
f"{get_eula_message(model_specs, kwargs.region)}"
681-
)
682-
683693
return kwargs
684694

685695

src/sagemaker/jumpstart/types.py

+13
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
619619
instance_type=instance_type, property_name="artifact_key"
620620
)
621621

622+
def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]:
623+
"""Returns instance specific training artifact key.
624+
625+
Returns None if a model, instance type tuple does not have specific
626+
training artifact key.
627+
"""
628+
629+
return self._get_instance_specific_property(
630+
instance_type=instance_type, property_name="training_artifact_uri"
631+
) or self._get_instance_specific_property(
632+
instance_type=instance_type, property_name="training_artifact_key"
633+
)
634+
622635
def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
623636
"""Returns instance specific resource requirements.
624637

src/sagemaker/jumpstart/utils.py

+41
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,47 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632
return neo_bucket
16331633

16341634

1635+
def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
1636+
init_kwargs: dict, accept_eula: Optional[bool]
1637+
):
1638+
"""Remove env vars if access configs are used
1639+
1640+
Args:
1641+
init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
1642+
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
1643+
"""
1644+
if accept_eula is not None and init_kwargs["environment"]:
1645+
del init_kwargs["environment"][constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY]
1646+
1647+
1648+
def get_hub_access_config(hub_content_arn: Optional[str]):
1649+
"""Get hub access config
1650+
1651+
Args:
1652+
hub_content_arn (Optional[bool]): Arn of the model reference hub content
1653+
"""
1654+
if hub_content_arn is not None:
1655+
hub_access_config = {"HubContentArn": hub_content_arn}
1656+
else:
1657+
hub_access_config = None
1658+
1659+
return hub_access_config
1660+
1661+
1662+
def get_model_access_config(accept_eula: Optional[bool]):
1663+
"""Get access configs
1664+
1665+
Args:
1666+
accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit().
1667+
"""
1668+
if accept_eula is not None:
1669+
model_access_config = {"AcceptEula": accept_eula}
1670+
else:
1671+
model_access_config = None
1672+
1673+
return model_access_config
1674+
1675+
16351676
def get_latest_version(versions: List[str]) -> Optional[str]:
16361677
"""Returns the latest version using sem-ver when possible."""
16371678
try:

0 commit comments

Comments
 (0)