@@ -181,6 +181,7 @@ def __init__(
181
181
container_arguments : Optional [List [str ]] = None ,
182
182
disable_output_compression : bool = False ,
183
183
enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
184
+ enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
184
185
** kwargs ,
185
186
):
186
187
"""Initialize an ``EstimatorBase`` instance.
@@ -544,7 +545,9 @@ def __init__(
544
545
enable_infra_check (bool or PipelineVariable): Optional.
545
546
Specifies whether it is running Sagemaker built-in infra check jobs.
546
547
enable_remote_debug (bool or PipelineVariable): Optional.
547
- Specifies whether RemoteDebug is enabled for the training job
548
+ Specifies whether RemoteDebug is enabled for the training job.
549
+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
550
+ Specifies whether SessionTagChaining is enabled for the training job.
548
551
"""
549
552
instance_count = renamed_kwargs (
550
553
"train_instance_count" , "instance_count" , instance_count , kwargs
@@ -785,6 +788,8 @@ def __init__(
785
788
786
789
self ._enable_remote_debug = enable_remote_debug
787
790
791
+ self ._enable_session_tag_chaining = enable_session_tag_chaining
792
+
788
793
@abstractmethod
789
794
def training_image_uri (self ):
790
795
"""Return the Docker image to use for training.
@@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
2318
2323
else {"EnableRemoteDebug" : self ._enable_remote_debug }
2319
2324
)
2320
2325
2326
+ def get_session_chaining_config (self ):
2327
+ """dict: Return the configuration of SessionChaining"""
2328
+ return (
2329
+ None
2330
+ if self ._enable_session_tag_chaining is None
2331
+ else {"EnableSessionTagChaining" : self ._enable_session_tag_chaining }
2332
+ )
2333
+
2321
2334
def enable_remote_debug (self ):
2322
2335
"""Enable remote debug for a training job."""
2323
2336
self ._update_remote_debug (True )
@@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
2574
2587
if estimator .get_remote_debug_config () is not None :
2575
2588
train_args ["remote_debug_config" ] = estimator .get_remote_debug_config ()
2576
2589
2590
+ if estimator .get_session_chaining_config () is not None :
2591
+ train_args ["session_chaining_config" ] = estimator .get_session_chaining_config ()
2592
+
2577
2593
return train_args
2578
2594
2579
2595
@classmethod
@@ -2766,6 +2782,7 @@ def __init__(
2766
2782
disable_output_compression : bool = False ,
2767
2783
enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
2768
2784
enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
2785
+ enable_session_tag_chaining : Optional [Union [bool , PipelineVariable ]] = None ,
2769
2786
** kwargs ,
2770
2787
):
2771
2788
"""Initialize an ``Estimator`` instance.
@@ -3129,6 +3146,8 @@ def __init__(
3129
3146
Specifies whether it is running Sagemaker built-in infra check jobs.
3130
3147
enable_remote_debug (bool or PipelineVariable): Optional.
3131
3148
Specifies whether RemoteDebug is enabled for the training job
3149
+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
3150
+ Specifies whether SessionTagChaining is enabled for the training job
3132
3151
"""
3133
3152
self .image_uri = image_uri
3134
3153
self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3181,6 +3200,7 @@ def __init__(
3181
3200
container_arguments = container_arguments ,
3182
3201
disable_output_compression = disable_output_compression ,
3183
3202
enable_remote_debug = enable_remote_debug ,
3203
+ enable_session_tag_chaining = enable_session_tag_chaining ,
3184
3204
** kwargs ,
3185
3205
)
3186
3206
0 commit comments