Skip to content

Commit fe32d79

Browse files
jessicazhu3Jessica Zhu
and
Jessica Zhu
authoredApr 24, 2024
feature: support session tag chaining for training job (#4596)
* feature: support session tag chaining for training job * fix: resolve typo * fix: resolve typo and build failure * fix: resolve typo and unit test failure --------- Co-authored-by: Jessica Zhu <jessicazhu3@106775307+jessicazhu3@users.noreply.github.com>
1 parent 30c9bf6 commit fe32d79

File tree

7 files changed

+92
-1
lines changed

7 files changed

+92
-1
lines changed
 

‎src/sagemaker/estimator.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def __init__(
181181
container_arguments: Optional[List[str]] = None,
182182
disable_output_compression: bool = False,
183183
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
184+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
184185
**kwargs,
185186
):
186187
"""Initialize an ``EstimatorBase`` instance.
@@ -544,7 +545,9 @@ def __init__(
544545
enable_infra_check (bool or PipelineVariable): Optional.
545546
Specifies whether it is running Sagemaker built-in infra check jobs.
546547
enable_remote_debug (bool or PipelineVariable): Optional.
547-
Specifies whether RemoteDebug is enabled for the training job
548+
Specifies whether RemoteDebug is enabled for the training job.
549+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
550+
Specifies whether SessionTagChaining is enabled for the training job.
548551
"""
549552
instance_count = renamed_kwargs(
550553
"train_instance_count", "instance_count", instance_count, kwargs
@@ -785,6 +788,8 @@ def __init__(
785788

786789
self._enable_remote_debug = enable_remote_debug
787790

791+
self._enable_session_tag_chaining = enable_session_tag_chaining
792+
788793
@abstractmethod
789794
def training_image_uri(self):
790795
"""Return the Docker image to use for training.
@@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
23182323
else {"EnableRemoteDebug": self._enable_remote_debug}
23192324
)
23202325

2326+
def get_session_chaining_config(self):
2327+
"""dict: Return the configuration of SessionChaining"""
2328+
return (
2329+
None
2330+
if self._enable_session_tag_chaining is None
2331+
else {"EnableSessionTagChaining": self._enable_session_tag_chaining}
2332+
)
2333+
23212334
def enable_remote_debug(self):
23222335
"""Enable remote debug for a training job."""
23232336
self._update_remote_debug(True)
@@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25742587
if estimator.get_remote_debug_config() is not None:
25752588
train_args["remote_debug_config"] = estimator.get_remote_debug_config()
25762589

2590+
if estimator.get_session_chaining_config() is not None:
2591+
train_args["session_chaining_config"] = estimator.get_session_chaining_config()
2592+
25772593
return train_args
25782594

25792595
@classmethod
@@ -2766,6 +2782,7 @@ def __init__(
27662782
disable_output_compression: bool = False,
27672783
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
27682784
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
2785+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
27692786
**kwargs,
27702787
):
27712788
"""Initialize an ``Estimator`` instance.
@@ -3129,6 +3146,8 @@ def __init__(
31293146
Specifies whether it is running Sagemaker built-in infra check jobs.
31303147
enable_remote_debug (bool or PipelineVariable): Optional.
31313148
Specifies whether RemoteDebug is enabled for the training job
3149+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
3150+
Specifies whether SessionTagChaining is enabled for the training job
31323151
"""
31333152
self.image_uri = image_uri
31343153
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3181,6 +3200,7 @@ def __init__(
31813200
container_arguments=container_arguments,
31823201
disable_output_compression=disable_output_compression,
31833202
enable_remote_debug=enable_remote_debug,
3203+
enable_session_tag_chaining=enable_session_tag_chaining,
31843204
**kwargs,
31853205
)
31863206

‎src/sagemaker/jumpstart/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
container_arguments: Optional[List[str]] = None,
110110
disable_output_compression: Optional[bool] = None,
111111
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
112+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
112113
):
113114
"""Initializes a ``JumpStartEstimator``.
114115
@@ -500,6 +501,8 @@ def __init__(
500501
to Amazon S3 without compression after training finishes.
501502
enable_remote_debug (bool or PipelineVariable): Optional.
502503
Specifies whether RemoteDebug is enabled for the training job
504+
enable_session_tag_chaining (bool or PipelineVariable): Optional.
505+
Specifies whether SessionTagChaining is enabled for the training job
503506
504507
Raises:
505508
ValueError: If the model ID is not recognized by JumpStart.
@@ -578,6 +581,7 @@ def _validate_model_id_and_get_type_hook():
578581
disable_output_compression=disable_output_compression,
579582
enable_infra_check=enable_infra_check,
580583
enable_remote_debug=enable_remote_debug,
584+
enable_session_tag_chaining=enable_session_tag_chaining,
581585
)
582586

583587
self.model_id = estimator_init_kwargs.model_id

‎src/sagemaker/jumpstart/factory/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def get_init_kwargs(
130130
disable_output_compression: Optional[bool] = None,
131131
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
132132
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
133+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
133134
) -> JumpStartEstimatorInitKwargs:
134135
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
135136

@@ -188,6 +189,7 @@ def get_init_kwargs(
188189
disable_output_compression=disable_output_compression,
189190
enable_infra_check=enable_infra_check,
190191
enable_remote_debug=enable_remote_debug,
192+
enable_session_tag_chaining=enable_session_tag_chaining,
191193
)
192194

193195
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)

‎src/sagemaker/jumpstart/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
17511751
"disable_output_compression",
17521752
"enable_infra_check",
17531753
"enable_remote_debug",
1754+
"enable_session_tag_chaining",
17541755
]
17551756

17561757
SERIALIZATION_EXCLUSION_SET = {
@@ -1818,6 +1819,7 @@ def __init__(
18181819
disable_output_compression: Optional[bool] = None,
18191820
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
18201821
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
1822+
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
18211823
) -> None:
18221824
"""Instantiates JumpStartEstimatorInitKwargs object."""
18231825

@@ -1877,6 +1879,7 @@ def __init__(
18771879
self.disable_output_compression = disable_output_compression
18781880
self.enable_infra_check = enable_infra_check
18791881
self.enable_remote_debug = enable_remote_debug
1882+
self.enable_session_tag_chaining = enable_session_tag_chaining
18801883

18811884

18821885
class JumpStartEstimatorFitKwargs(JumpStartKwargs):

‎src/sagemaker/session.py

+24
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ def train( # noqa: C901
758758
environment: Optional[Dict[str, str]] = None,
759759
retry_strategy=None,
760760
remote_debug_config=None,
761+
session_chaining_config=None,
761762
):
762763
"""Create an Amazon SageMaker training job.
763764
@@ -877,6 +878,15 @@ def train( # noqa: C901
877878
remote_debug_config = {
878879
"EnableRemoteDebug": True,
879880
}
881+
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
882+
The dict can contain 'EnableSessionTagChaining'(bool).
883+
For example,
884+
885+
.. code:: python
886+
887+
session_chaining_config = {
888+
"EnableSessionTagChaining": True,
889+
}
880890
environment (dict[str, str]) : Environment variables to be set for
881891
use during training job (default: ``None``)
882892
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -970,6 +980,7 @@ def train( # noqa: C901
970980
profiler_rule_configs=profiler_rule_configs,
971981
profiler_config=inferred_profiler_config,
972982
remote_debug_config=remote_debug_config,
983+
session_chaining_config=session_chaining_config,
973984
environment=environment,
974985
retry_strategy=retry_strategy,
975986
)
@@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901
10131024
profiler_rule_configs=None,
10141025
profiler_config=None,
10151026
remote_debug_config=None,
1027+
session_chaining_config=None,
10161028
environment=None,
10171029
retry_strategy=None,
10181030
):
@@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901
11331145
remote_debug_config = {
11341146
"EnableRemoteDebug": True,
11351147
}
1148+
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
1149+
The dict can contain 'EnableSessionTagChaining'(bool).
1150+
For example,
1151+
1152+
.. code:: python
1153+
1154+
session_chaining_config = {
1155+
"EnableSessionTagChaining": True,
1156+
}
11361157
environment (dict[str, str]) : Environment variables to be set for
11371158
use during training job (default: ``None``)
11381159
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901
12391260
if remote_debug_config is not None:
12401261
train_request["RemoteDebugConfig"] = remote_debug_config
12411262

1263+
if session_chaining_config is not None:
1264+
train_request["SessionChainingConfig"] = session_chaining_config
1265+
12421266
if retry_strategy is not None:
12431267
train_request["RetryStrategy"] = retry_strategy
12441268

‎tests/unit/test_estimator.py

+35
Original file line numberDiff line numberDiff line change
@@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session):
20892089
assert len(args) == 2
20902090

20912091

2092+
def test_framework_with_session_chaining_config(sagemaker_session):
2093+
f = DummyFramework(
2094+
entry_point=SCRIPT_PATH,
2095+
role=ROLE,
2096+
sagemaker_session=sagemaker_session,
2097+
instance_groups=[
2098+
InstanceGroup("group1", "ml.c4.xlarge", 1),
2099+
InstanceGroup("group2", "ml.m4.xlarge", 2),
2100+
],
2101+
enable_session_tag_chaining=True,
2102+
)
2103+
f.fit("s3://mydata")
2104+
sagemaker_session.train.assert_called_once()
2105+
_, args = sagemaker_session.train.call_args
2106+
assert args["session_chaining_config"]["EnableSessionTagChaining"]
2107+
assert f.get_session_chaining_config()["EnableSessionTagChaining"]
2108+
2109+
2110+
def test_framework_without_session_chaining_config(sagemaker_session):
2111+
f = DummyFramework(
2112+
entry_point=SCRIPT_PATH,
2113+
role=ROLE,
2114+
sagemaker_session=sagemaker_session,
2115+
instance_groups=[
2116+
InstanceGroup("group1", "ml.c4.xlarge", 1),
2117+
InstanceGroup("group2", "ml.m4.xlarge", 2),
2118+
],
2119+
)
2120+
f.fit("s3://mydata")
2121+
sagemaker_session.train.assert_called_once()
2122+
_, args = sagemaker_session.train.call_args
2123+
assert args.get("SessionTagChaining") is None
2124+
assert f.get_remote_debug_config() is None
2125+
2126+
20922127
@patch("time.strftime", return_value=TIMESTAMP)
20932128
def test_custom_code_bucket(time, sagemaker_session):
20942129
code_bucket = "codebucket"

‎tests/unit/test_session.py

+3
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
21972197
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
21982198
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
21992199
remote_debug_config = {"EnableRemoteDebug": True}
2200+
session_chaining_config = {"EnableSessionTagChaining": True}
22002201

22012202
sagemaker_session.train(
22022203
image_uri=IMAGE,
@@ -2222,6 +2223,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
22222223
container_entry_point=CONTAINER_ENTRY_POINT,
22232224
container_arguments=CONTAINER_ARGUMENTS,
22242225
remote_debug_config=remote_debug_config,
2226+
session_chaining_config=session_chaining_config,
22252227
)
22262228

22272229
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -2245,6 +2247,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
22452247
)
22462248
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
22472249
assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"]
2250+
assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"]
22482251

22492252

22502253
def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):

0 commit comments

Comments
 (0)