Skip to content

Commit 1e76572

Browse files
makungaj1Jonathan Makunga
authored andcommitted
Benchmark feature fixes (aws#4632)
* Filter down Benchmark Metrics * Filter down Benchmark Metrics * Testing NB * Testing MB * Testing * Refactoring * Unit tests * Display instance type first, and instance rate last * Display unbalanced metrics * Testing with NB * Testing with NB * Debug * Debug * Testing with NB * Testing with NB * Testing with NB * Refactoring * Refactoring * Refactoring * Unit tests * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Custom lru * Refactoring * Debug * Config ranking * Debug * Debug * Debug * Debug * Debug * Ranking * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Ranking-Debug * Debug * Debug * Debug * Debug * Refactoring * Contact JumpStart team to fix flaky test. test_list_jumpstart_models_script_filter --------- Co-authored-by: Jonathan Makunga <makung@amazon.com>
1 parent c9f4d30 commit 1e76572

File tree

8 files changed

+214
-134
lines changed

8 files changed

+214
-134
lines changed

src/sagemaker/jumpstart/model.py

+34-61
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from __future__ import absolute_import
1616

17-
from functools import lru_cache
1817
from typing import Dict, List, Optional, Any, Union
1918
import pandas as pd
2019
from botocore.exceptions import ClientError
@@ -48,6 +47,8 @@
4847
get_jumpstart_configs,
4948
get_metrics_from_deployment_configs,
5049
add_instance_rate_stats_to_benchmark_metrics,
50+
deployment_config_response_data,
51+
_deployment_config_lru_cache,
5152
)
5253
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
5354
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -449,10 +450,12 @@ def deployment_config(self) -> Optional[Dict[str, Any]]:
449450
Returns:
450451
Optional[Dict[str, Any]]: Deployment config.
451452
"""
452-
deployment_config = self._retrieve_selected_deployment_config(
453-
self.config_name, self.instance_type
454-
)
455-
return deployment_config.to_json() if deployment_config is not None else None
453+
if self.config_name is None:
454+
return None
455+
for config in self.list_deployment_configs():
456+
if config.get("DeploymentConfigName") == self.config_name:
457+
return config
458+
return None
456459

457460
@property
458461
def benchmark_metrics(self) -> pd.DataFrame:
@@ -461,29 +464,24 @@ def benchmark_metrics(self) -> pd.DataFrame:
461464
Returns:
462465
Benchmark Metrics: Pandas DataFrame object.
463466
"""
464-
benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
465-
self.config_name, self.instance_type
466-
)
467-
keys = list(benchmark_metrics_data.keys())
468-
df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]])
469-
return df
467+
df = pd.DataFrame(self._get_deployment_configs_benchmarks_data())
468+
default_mask = df.apply(lambda row: any("Default" in str(val) for val in row), axis=1)
469+
sorted_df = pd.concat([df[default_mask], df[~default_mask]])
470+
return sorted_df
470471

471-
def display_benchmark_metrics(self) -> None:
472+
def display_benchmark_metrics(self, *args, **kwargs) -> None:
472473
"""Display deployment configs benchmark metrics."""
473-
print(self.benchmark_metrics.to_markdown(index=False))
474+
print(self.benchmark_metrics.to_markdown(index=False), *args, **kwargs)
474475

475476
def list_deployment_configs(self) -> List[Dict[str, Any]]:
476477
"""List deployment configs for ``This`` model.
477478
478479
Returns:
479480
List[Dict[str, Any]]: A list of deployment configs.
480481
"""
481-
return [
482-
deployment_config.to_json()
483-
for deployment_config in self._get_deployment_configs(
484-
self.config_name, self.instance_type
485-
)
486-
]
482+
return deployment_config_response_data(
483+
self._get_deployment_configs(self.config_name, self.instance_type)
484+
)
487485

488486
def _create_sagemaker_model(
489487
self,
@@ -873,71 +871,46 @@ def register_deploy_wrapper(*args, **kwargs):
873871

874872
return model_package
875873

876-
@lru_cache
877-
def _get_deployment_configs_benchmarks_data(
878-
self, config_name: str, instance_type: str
879-
) -> Dict[str, Any]:
874+
@_deployment_config_lru_cache
875+
def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]:
880876
"""Deployment configs benchmark metrics.
881877
882-
Args:
883-
config_name (str): Name of selected deployment config.
884-
instance_type (str): The selected Instance type.
885878
Returns:
886879
Dict[str, List[str]]: Deployment config benchmark data.
887880
"""
888881
return get_metrics_from_deployment_configs(
889-
self._get_deployment_configs(config_name, instance_type)
882+
self._get_deployment_configs(None, None),
890883
)
891884

892-
@lru_cache
893-
def _retrieve_selected_deployment_config(
894-
self, config_name: str, instance_type: str
895-
) -> Optional[DeploymentConfigMetadata]:
896-
"""Retrieve the deployment config to apply to `This` model.
897-
898-
Args:
899-
config_name (str): The name of the deployment config to retrieve.
900-
instance_type (str): The instance type of the deployment config to retrieve.
901-
Returns:
902-
Optional[Dict[str, Any]]: The retrieved deployment config.
903-
"""
904-
if config_name is None:
905-
return None
906-
907-
for deployment_config in self._get_deployment_configs(config_name, instance_type):
908-
if deployment_config.deployment_config_name == config_name:
909-
return deployment_config
910-
return None
911-
912-
@lru_cache
885+
@_deployment_config_lru_cache
913886
def _get_deployment_configs(
914-
self, selected_config_name: str, selected_instance_type: str
887+
self, selected_config_name: Optional[str], selected_instance_type: Optional[str]
915888
) -> List[DeploymentConfigMetadata]:
916889
"""Retrieve deployment configs metadata.
917890
918891
Args:
919-
selected_config_name (str): The name of the selected deployment config.
920-
selected_instance_type (str): The selected instance type.
892+
selected_config_name (Optional[str]): The name of the selected deployment config.
893+
selected_instance_type (Optional[str]): The selected instance type.
921894
"""
922895
deployment_configs = []
923-
if self._metadata_configs is None:
896+
if not self._metadata_configs:
924897
return deployment_configs
925898

926899
err = None
927900
for config_name, metadata_config in self._metadata_configs.items():
928-
if err is None or "is not authorized to perform: #:GetProducts" not in err:
929-
err, metadata_config.benchmark_metrics = (
930-
add_instance_rate_stats_to_benchmark_metrics(
931-
self.region, metadata_config.benchmark_metrics
932-
)
933-
)
934-
935901
resolved_config = metadata_config.resolved_config
936902
if selected_config_name == config_name:
937903
instance_type_to_use = selected_instance_type
938904
else:
939905
instance_type_to_use = resolved_config.get("default_inference_instance_type")
940906

907+
if metadata_config.benchmark_metrics:
908+
err, metadata_config.benchmark_metrics = (
909+
add_instance_rate_stats_to_benchmark_metrics(
910+
self.region, metadata_config.benchmark_metrics
911+
)
912+
)
913+
941914
init_kwargs = get_init_kwargs(
942915
model_id=self.model_id,
943916
instance_type=instance_type_to_use,
@@ -957,9 +930,9 @@ def _get_deployment_configs(
957930
)
958931
deployment_configs.append(deployment_config_metadata)
959932

960-
if err is not None and "is not authorized to perform: #:GetProducts" in err:
933+
if err and err["Code"] == "AccessDeniedException":
961934
error_message = "Instance rate metrics will be omitted. Reason: %s"
962-
JUMPSTART_LOGGER.warning(error_message, err)
935+
JUMPSTART_LOGGER.warning(error_message, err["Message"])
963936

964937
return deployment_configs
965938

src/sagemaker/jumpstart/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -2258,6 +2258,8 @@ def _val_to_json(self, val: Any) -> Any:
22582258
Any: The converted json value.
22592259
"""
22602260
if issubclass(type(val), JumpStartDataHolderType):
2261+
if isinstance(val, JumpStartBenchmarkStat):
2262+
val.name = val.name.replace("_", " ").title()
22612263
return val.to_json()
22622264
if isinstance(val, list):
22632265
list_obj = []

0 commit comments

Comments
 (0)