Skip to content

Commit 0a6c298

Browse files
makungaj1Jonathan Makunga
and
Jonathan Makunga
authored
ModelBuilder: Add functionalities to get and set deployment config. (#4614)
* Add functionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage --------- Co-authored-by: Jonathan Makunga <makung@amazon.com>
1 parent b92aa3c commit 0a6c298

File tree

9 files changed

+348
-69
lines changed

9 files changed

+348
-69
lines changed

src/sagemaker/jumpstart/model.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import absolute_import
1616

1717
from functools import lru_cache
18-
from typing import Dict, List, Optional, Union, Any
18+
from typing import Dict, List, Optional, Any, Union
1919
import pandas as pd
2020
from botocore.exceptions import ClientError
2121

@@ -441,14 +441,23 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
441441
model_id=self.model_id, model_version=self.model_version, config_name=config_name
442442
)
443443

444+
@property
def deployment_config(self) -> Optional[Dict[str, Any]]:
    """Deployment config currently selected for this model.

    Looks up the config named by ``self.config_name`` through
    ``self._retrieve_selected_deployment_config``.

    Returns:
        Optional[Dict[str, Any]]: Whatever the lookup returns for the
            currently selected config name.
    """
    selected_name = self.config_name
    return self._retrieve_selected_deployment_config(selected_name)
452+
444453
@property
445454
def benchmark_metrics(self) -> pd.DataFrame:
446455
"""Benchmark Metrics for deployment configs
447456
448457
Returns:
449458
Metrics: Pandas DataFrame object.
450459
"""
451-
return pd.DataFrame(self._get_benchmark_data(self.config_name))
460+
return pd.DataFrame(self._get_benchmarks_data(self.config_name))
452461

453462
def display_benchmark_metrics(self) -> None:
454463
"""Display Benchmark Metrics for deployment configs."""
@@ -851,8 +860,8 @@ def register_deploy_wrapper(*args, **kwargs):
851860
return model_package
852861

853862
@lru_cache
854-
def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
855-
"""Constructs deployment configs benchmark data.
863+
def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
864+
"""Deployment configs benchmark metrics.
856865
857866
Args:
858867
config_name (str): The name of the selected deployment config.
@@ -864,6 +873,23 @@ def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
864873
config_name,
865874
)
866875

876+
@lru_cache
877+
def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
878+
"""Retrieve the deployment config to apply to the model.
879+
880+
Args:
881+
config_name (str): The name of the deployment config to retrieve.
882+
Returns:
883+
Optional[Dict[str, Any]]: The retrieved deployment config.
884+
"""
885+
if config_name is None:
886+
return None
887+
888+
for deployment_config in self._deployment_configs:
889+
if deployment_config.get("DeploymentConfigName") == config_name:
890+
return deployment_config
891+
return None
892+
867893
def _convert_to_deployment_config_metadata(
868894
self, config_name: str, metadata_config: JumpStartMetadataConfig
869895
) -> Dict[str, Any]:

src/sagemaker/jumpstart/types.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -2251,17 +2251,17 @@ def to_json(self) -> Dict[str, Any]:
22512251
return json_obj
22522252

22532253

2254-
class DeploymentConfig(BaseDeploymentConfigDataHolder):
2254+
class DeploymentArgs(BaseDeploymentConfigDataHolder):
22552255
"""Dataclass representing a Deployment Config."""
22562256

22572257
__slots__ = [
2258-
"model_data_download_timeout",
2259-
"container_startup_health_check_timeout",
22602258
"image_uri",
22612259
"model_data",
2262-
"instance_type",
22632260
"environment",
2261+
"instance_type",
22642262
"compute_resource_requirements",
2263+
"model_data_download_timeout",
2264+
"container_startup_health_check_timeout",
22652265
]
22662266

22672267
def __init__(
@@ -2288,9 +2288,10 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
22882288
"""Dataclass representing a Deployment Config Metadata"""
22892289

22902290
__slots__ = [
2291-
"config_name",
2291+
"deployment_config_name",
2292+
"deployment_args",
2293+
"acceleration_configs",
22922294
"benchmark_metrics",
2293-
"deployment_config",
22942295
]
22952296

22962297
def __init__(
@@ -2301,6 +2302,7 @@ def __init__(
23012302
deploy_kwargs: JumpStartModelDeployKwargs,
23022303
):
23032304
"""Instantiates DeploymentConfigMetadata object."""
2304-
self.config_name = config_name
2305+
self.deployment_config_name = config_name
2306+
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs)
2307+
self.acceleration_configs = None
23052308
self.benchmark_metrics = benchmark_metrics
2306-
self.deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs)

src/sagemaker/jumpstart/utils.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1040,24 +1040,40 @@ def extract_metrics_from_deployment_configs(
10401040
config_name (str): The name of the deployment config use by the model.
10411041
"""
10421042

1043-
data = {"Config Name": [], "Instance Type": [], "Selected": []}
1043+
data = {"Config Name": [], "Instance Type": [], "Selected": [], "Accelerated": []}
10441044

10451045
for index, deployment_config in enumerate(deployment_configs):
1046-
if deployment_config.get("DeploymentConfig") is None:
1046+
if deployment_config.get("DeploymentArgs") is None:
10471047
continue
10481048

10491049
benchmark_metrics = deployment_config.get("BenchmarkMetrics")
10501050
if benchmark_metrics is not None:
1051-
data["Config Name"].append(deployment_config.get("ConfigName"))
1051+
data["Config Name"].append(deployment_config.get("DeploymentConfigName"))
10521052
data["Instance Type"].append(
1053-
deployment_config.get("DeploymentConfig").get("InstanceType")
1053+
deployment_config.get("DeploymentArgs").get("InstanceType")
10541054
)
10551055
data["Selected"].append(
10561056
"Yes"
1057-
if (config_name is not None and config_name == deployment_config.get("ConfigName"))
1057+
if (
1058+
config_name is not None
1059+
and config_name == deployment_config.get("DeploymentConfigName")
1060+
)
10581061
else "No"
10591062
)
10601063

1064+
accelerated_configs = deployment_config.get("AccelerationConfigs")
1065+
if accelerated_configs is None:
1066+
data["Accelerated"].append("No")
1067+
else:
1068+
data["Accelerated"].append(
1069+
"Yes"
1070+
if (
1071+
len(accelerated_configs) > 0
1072+
and accelerated_configs[0].get("Enabled", False)
1073+
)
1074+
else "No"
1075+
)
1076+
10611077
if index == 0:
10621078
for benchmark_metric in benchmark_metrics:
10631079
column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})"
@@ -1068,4 +1084,6 @@ def extract_metrics_from_deployment_configs(
10681084
if column_name in data.keys():
10691085
data[column_name].append(benchmark_metric.get("value"))
10701086

1087+
if "Yes" not in data["Accelerated"]:
1088+
del data["Accelerated"]
10711089
return data

src/sagemaker/serve/builder/jumpstart_builder.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import copy
1717
from abc import ABC, abstractmethod
1818
from datetime import datetime, timedelta
19-
from typing import Type, Any, List, Dict
19+
from typing import Type, Any, List, Dict, Optional
2020
import logging
2121

2222
from sagemaker.model import Model
@@ -431,8 +431,35 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
431431
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
432432
)
433433

434+
def set_deployment_config(self, config_name: Optional[str]) -> None:
    """Apply a named deployment config to the underlying model.

    Args:
        config_name (Optional[str]):
            The name of the deployment config. Pass None to unset
            any existing config that is applied to the model.

    Raises:
        Exception: If ``pysdk_model`` is missing or ``None`` on this builder.
    """
    model = getattr(self, "pysdk_model", None)
    if model is None:
        raise Exception("Cannot set deployment config to an uninitialized model.")

    # Delegate to the model, which owns the selected config.
    model.set_deployment_config(config_name)
446+
447+
def get_deployment_config(self) -> Optional[Dict[str, Any]]:
    """Return the deployment config selected for this model.

    If ``pysdk_model`` is missing or ``None``, it is first created with
    ``self._create_pre_trained_js_model()`` and stored on the builder.

    Returns:
        Optional[Dict[str, Any]]: The ``deployment_config`` of ``pysdk_model``.
    """
    if getattr(self, "pysdk_model", None) is None:
        # Build the model on first use so this getter works before a build.
        self.pysdk_model = self._create_pre_trained_js_model()

    return self.pysdk_model.deployment_config
457+
434458
def display_benchmark_metrics(self):
    """Display Markdown Benchmark Metrics for deployment configs.

    If ``pysdk_model`` is missing or ``None``, it is first created with
    ``self._create_pre_trained_js_model()`` and stored on the builder.
    """
    if getattr(self, "pysdk_model", None) is None:
        # Build the model on first use so metrics can be shown before a build.
        self.pysdk_model = self._create_pre_trained_js_model()

    self.pysdk_model.display_benchmark_metrics()
437464

438465
def list_deployment_configs(self) -> List[Dict[str, Any]]:
@@ -441,6 +468,9 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
441468
Returns:
442469
List[Dict[str, Any]]: A list of deployment configs.
443470
"""
471+
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
472+
self.pysdk_model = self._create_pre_trained_js_model()
473+
444474
return self.pysdk_model.list_deployment_configs()
445475

446476
def _build_for_jumpstart(self):
@@ -449,32 +479,29 @@ def _build_for_jumpstart(self):
449479
self.secret_key = None
450480
self.jumpstart = True
451481

452-
pysdk_model = self._create_pre_trained_js_model()
453-
454-
image_uri = pysdk_model.image_uri
482+
if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
483+
self.pysdk_model = self._create_pre_trained_js_model()
455484

456-
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
485+
logger.info(
486+
"JumpStart ID %s is packaged with Image URI: %s", self.model, self.pysdk_model.image_uri
487+
)
457488

458-
if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
489+
if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT:
459490
raise ValueError(
460491
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
461492
)
462493

463-
if "djl-inference" in image_uri:
494+
if "djl-inference" in self.pysdk_model.image_uri:
464495
logger.info("Building for DJL JumpStart Model ID...")
465496
self.model_server = ModelServer.DJL_SERVING
466-
467-
self.pysdk_model = pysdk_model
468497
self.image_uri = self.pysdk_model.image_uri
469498

470499
self._build_for_djl_jumpstart()
471500

472501
self.pysdk_model.tune = self.tune_for_djl_jumpstart
473-
elif "tgi-inference" in image_uri:
502+
elif "tgi-inference" in self.pysdk_model.image_uri:
474503
logger.info("Building for TGI JumpStart Model ID...")
475504
self.model_server = ModelServer.TGI
476-
477-
self.pysdk_model = pysdk_model
478505
self.image_uri = self.pysdk_model.image_uri
479506

480507
self._build_for_tgi_jumpstart()
@@ -487,15 +514,13 @@ def _build_for_jumpstart(self):
487514

488515
return self.pysdk_model
489516

490-
def _is_gated_model(self, model) -> bool:
517+
def _is_gated_model(self) -> bool:
491518
"""Determine if ``this`` Model is Gated
492519
493-
Args:
494-
model (Model): Jumpstart Model
495520
Returns:
496521
bool: ``True`` if ``this`` Model is Gated
497522
"""
498-
s3_uri = model.model_data
523+
s3_uri = self.pysdk_model.model_data
499524
if isinstance(s3_uri, dict):
500525
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
501526

0 commit comments

Comments
 (0)