Skip to content

Add ReadOnly APIs #4606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 24, 2024
109 changes: 106 additions & 3 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from __future__ import absolute_import

from typing import Dict, List, Optional, Union
from functools import lru_cache
from typing import Dict, List, Optional, Union, Any
import pandas as pd
from botocore.exceptions import ClientError

from sagemaker import payloads
Expand All @@ -36,14 +38,21 @@
get_init_kwargs,
get_register_kwargs,
)
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.types import (
JumpStartSerializablePayload,
DeploymentConfigMetadata,
JumpStartBenchmarkStat,
JumpStartMetadataConfig,
)
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
verify_model_region_and_return_specs,
get_jumpstart_configs,
extract_metrics_from_deployment_configs,
)
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
from sagemaker.model import (
Model,
ModelPackage,
Expand Down Expand Up @@ -352,6 +361,18 @@ def _validate_model_id_and_type():
self.model_package_arn = model_init_kwargs.model_package_arn
self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)

metadata_configs = get_jumpstart_configs(
region=self.region,
model_id=self.model_id,
model_version=self.model_version,
sagemaker_session=self.sagemaker_session,
model_type=self.model_type,
)
self._deployment_configs = [
self._convert_to_deployment_config_metadata(config_name, config)
for config_name, config in metadata_configs.items()
]

def log_subscription_warning(self) -> None:
"""Log message prompting the customer to subscribe to the proprietary model."""
subscription_link = verify_model_region_and_return_specs(
Expand Down Expand Up @@ -420,6 +441,27 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
model_id=self.model_id, model_version=self.model_version, config_name=config_name
)

@property
def benchmark_metrics(self) -> pd.DataFrame:
    """Return the deployment configs' benchmark metrics as a table.

    Returns:
        pd.DataFrame: Built from the data ``_get_benchmark_data`` yields for
            the model's current ``config_name``.
    """
    benchmark_data = self._get_benchmark_data(self.config_name)
    return pd.DataFrame(benchmark_data)

def display_benchmark_metrics(self) -> None:
    """Print the benchmark metrics table rendered as Markdown."""
    markdown_table = self.benchmark_metrics.to_markdown()
    print(markdown_table)

def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """Return the deployment configs held by ``This`` model.

    Returns:
        List[Dict[str, Any]]: The list stored in ``_deployment_configs``
            (the same list object, not a copy).
    """
    deployment_configs = self._deployment_configs
    return deployment_configs

def _create_sagemaker_model(
self,
instance_type=None,
Expand Down Expand Up @@ -808,6 +850,67 @@ def register_deploy_wrapper(*args, **kwargs):

return model_package

def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
    """Constructs deployment configs benchmark data.

    Results are memoized per instance and per ``config_name``. A
    ``functools.lru_cache`` on the method would key a process-wide cache on
    ``self``, keeping every model instance alive for the cache's lifetime and
    requiring the instance to be hashable; a dict stored on the instance
    avoids both problems.

    Args:
        config_name (str): The name of the selected deployment config.
    Returns:
        Dict[str, List[str]]: Deployment config benchmark data.
    """
    cache = getattr(self, "_benchmark_data_cache", None)
    if cache is None:
        # Created lazily so the method does not depend on __init__ changes.
        cache = {}
        self._benchmark_data_cache = cache

    if config_name not in cache:
        cache[config_name] = extract_metrics_from_deployment_configs(
            self._deployment_configs,
            config_name,
        )
    return cache[config_name]

def _convert_to_deployment_config_metadata(
    self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
    """Retrieve deployment config for config name.

    Args:
        config_name (str): Name of deployment config.
        metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
    Returns:
        A deployment metadata config for config name (dict[str, Any]).
    """
    default_inference_instance_type = metadata_config.resolved_config.get(
        "default_inference_instance_type"
    )

    instance_rate = get_instance_rate_per_hour(
        instance_type=default_inference_instance_type, region=self.region
    )

    source_metrics = None
    if metadata_config.benchmark_metrics is not None:
        source_metrics = metadata_config.benchmark_metrics.get(default_inference_instance_type)

    # Work on a copy: appending to the list owned by ``metadata_config`` would
    # mutate the caller's metadata in place, so converting the same metadata
    # again would accumulate duplicate instance-rate entries.
    benchmark_metrics = list(source_metrics) if source_metrics is not None else None

    if instance_rate is not None:
        rate_stat = JumpStartBenchmarkStat(instance_rate)
        benchmark_metrics = (benchmark_metrics or []) + [rate_stat]

    init_kwargs = get_init_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )
    deploy_kwargs = get_deploy_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )

    deployment_config_metadata = DeploymentConfigMetadata(
        config_name, benchmark_metrics, init_kwargs, deploy_kwargs
    )

    return deployment_config_metadata.to_json()

def __str__(self) -> str:
    """Return the human-readable form produced by ``stringify_object``."""
    readable = stringify_object(self)
    return readable
96 changes: 96 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,3 +2206,99 @@ def __init__(
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
self.config_name = config_name


class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
    """Base class for Deployment Config Data."""

    def _convert_to_pascal_case(self, attr_name: str) -> str:
        """Converts a snake_case attribute name into a PascalCased string.

        Args:
            attr_name (str): The snake_case attribute name.
        Returns:
            str: The PascalCased attribute name, e.g. ``config_name`` -> ``ConfigName``.
        """
        return attr_name.replace("_", " ").title().replace(" ", "")

    def to_json(self) -> Dict[str, Any]:
        """Represents ``This`` object as JSON.

        Every slot that has been assigned is emitted under its PascalCased
        name; slots that were never set are omitted.

        Returns:
            Dict[str, Any]: The JSON-style representation of this object.
        """
        json_obj = {}
        for att in self.__slots__:
            if not hasattr(self, att):
                continue
            # Per-value conversion lives in a helper so this loop stays flat
            # (the nesting depth here was flagged in review).
            json_obj[self._convert_to_pascal_case(att)] = self._to_json_value(
                getattr(self, att)
            )
        return json_obj

    def _to_json_value(self, value: Any) -> Any:
        """Converts one attribute value into its JSON-style form.

        Data holders are serialized through their own ``to_json``. Lists and
        dicts are converted one level deep: contained data holders are
        serialized and every other element is passed through unchanged.

        Args:
            value (Any): The attribute value to convert.
        Returns:
            Any: The converted value.
        """
        if isinstance(value, JumpStartDataHolderType):
            return value.to_json()

        if isinstance(value, list):
            return [
                item.to_json() if isinstance(item, JumpStartDataHolderType) else item
                for item in value
            ]

        if isinstance(value, dict):
            converted = {}
            for key, val in value.items():
                if isinstance(val, JumpStartDataHolderType):
                    # Only keys that map to data holders are PascalCased;
                    # keys of plain values are kept verbatim.
                    converted[self._convert_to_pascal_case(key)] = val.to_json()
                else:
                    converted[key] = val
            return converted

        return value


class DeploymentConfig(BaseDeploymentConfigDataHolder):
    """Deployment settings collected from a model's init and deploy kwargs."""

    __slots__ = [
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "image_uri",
        "model_data",
        "instance_type",
        "environment",
        "compute_resource_requirements",
    ]

    def __init__(
        self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
    ):
        """Populate the config from whichever kwargs objects are provided.

        Attributes sourced from an argument that is ``None`` are left unset.
        ``compute_resource_requirements`` is likewise left unset when
        ``init_kwargs.resources`` is ``None``.
        """
        if init_kwargs is not None:
            self.image_uri = init_kwargs.image_uri
            self.model_data = init_kwargs.model_data
            self.instance_type = init_kwargs.instance_type
            self.environment = init_kwargs.env
            resources = init_kwargs.resources
            if resources is not None:
                self.compute_resource_requirements = resources.get_compute_resource_requirements()

        if deploy_kwargs is not None:
            download_timeout = deploy_kwargs.model_data_download_timeout
            health_check_timeout = deploy_kwargs.container_startup_health_check_timeout
            self.model_data_download_timeout = download_timeout
            self.container_startup_health_check_timeout = health_check_timeout


class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
    """A named deployment config together with its benchmark metrics."""

    __slots__ = [
        "config_name",
        "benchmark_metrics",
        "deployment_config",
    ]

    def __init__(
        self,
        config_name: str,
        benchmark_metrics: List[JumpStartBenchmarkStat],
        init_kwargs: JumpStartModelInitKwargs,
        deploy_kwargs: JumpStartModelDeployKwargs,
    ):
        """Store the name and metrics and build the nested ``DeploymentConfig``."""
        deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs)

        self.config_name = config_name
        self.benchmark_metrics = benchmark_metrics
        self.deployment_config = deployment_config
41 changes: 41 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,3 +999,44 @@ def get_jumpstart_configs(
if metadata_configs
else {}
)


def extract_metrics_from_deployment_configs(
    deployment_configs: List[Dict[str, Any]], config_name: str
) -> Dict[str, List[str]]:
    """Extracts metrics from deployment configs.

    Each deployment config that has both a ``DeploymentConfig`` and a
    ``BenchmarkMetrics`` entry contributes one row. Metric columns are named
    ``"<name> (<unit>)"`` and are the union of the metrics seen across all
    rows. A row that lacks a metric gets ``None`` in that column, so every
    column always has exactly one entry per row and the result can be handed
    directly to ``pandas.DataFrame``.

    Args:
        deployment_configs (list[dict[str, Any]]): List of deployment configs.
        config_name (str): The name of the deployment config used by the model.
    Returns:
        Dict[str, List[str]]: Column name mapped to that column's values, one per row.
    """
    data = {"Config Name": [], "Instance Type": [], "Selected": []}
    metric_columns = []

    for deployment_config in deployment_configs:
        inner_config = deployment_config.get("DeploymentConfig")
        benchmark_metrics = deployment_config.get("BenchmarkMetrics")
        if inner_config is None or benchmark_metrics is None:
            continue

        rows_so_far = len(data["Config Name"])
        current_name = deployment_config.get("ConfigName")

        data["Config Name"].append(current_name)
        data["Instance Type"].append(inner_config.get("InstanceType"))
        data["Selected"].append(
            "Yes" if (config_name is not None and config_name == current_name) else "No"
        )

        # One value per column for this row; a repeated metric keeps its first value.
        row_values = {}
        for benchmark_metric in benchmark_metrics:
            column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})"
            row_values.setdefault(column_name, benchmark_metric.get("value"))

        # A metric first seen on a later row is back-filled for the earlier rows.
        for column_name in row_values:
            if column_name not in data:
                data[column_name] = [None] * rows_so_far
                metric_columns.append(column_name)

        for column_name in metric_columns:
            data[column_name].append(row_values.get(column_name))

    return data
14 changes: 13 additions & 1 deletion src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Type
from typing import Type, Any, List, Dict
import logging

from sagemaker.model import Model
Expand Down Expand Up @@ -431,6 +431,18 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)

def display_benchmark_metrics(self):
    """Display the deployment configs' benchmark metrics via the wrapped ``pysdk_model``."""
    wrapped_model = self.pysdk_model
    wrapped_model.display_benchmark_metrics()

def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """Return the deployment configs reported by the wrapped ``pysdk_model``.

    Returns:
        List[Dict[str, Any]]: A list of deployment configs.
    """
    wrapped_model = self.pysdk_model
    return wrapped_model.list_deployment_configs()

def _build_for_jumpstart(self):
"""Placeholder docstring"""
# we do not pickle for jumpstart. set to none
Expand Down
Loading