14
14
15
15
from __future__ import absolute_import
16
16
17
- from functools import lru_cache
18
17
from typing import Dict , List , Optional , Any , Union
19
18
import pandas as pd
20
19
from botocore .exceptions import ClientError
48
47
get_jumpstart_configs ,
49
48
get_metrics_from_deployment_configs ,
50
49
add_instance_rate_stats_to_benchmark_metrics ,
50
+ deployment_config_response_data ,
51
+ _deployment_config_lru_cache ,
51
52
)
52
53
from sagemaker .jumpstart .constants import JUMPSTART_LOGGER
53
54
from sagemaker .jumpstart .enums import JumpStartModelType
@@ -449,10 +450,12 @@ def deployment_config(self) -> Optional[Dict[str, Any]]:
449
450
Returns:
450
451
Optional[Dict[str, Any]]: Deployment config.
451
452
"""
452
- deployment_config = self ._retrieve_selected_deployment_config (
453
- self .config_name , self .instance_type
454
- )
455
- return deployment_config .to_json () if deployment_config is not None else None
453
+ if self .config_name is None :
454
+ return None
455
+ for config in self .list_deployment_configs ():
456
+ if config .get ("DeploymentConfigName" ) == self .config_name :
457
+ return config
458
+ return None
456
459
457
460
@property
458
461
def benchmark_metrics (self ) -> pd .DataFrame :
@@ -461,29 +464,24 @@ def benchmark_metrics(self) -> pd.DataFrame:
461
464
Returns:
462
465
Benchmark Metrics: Pandas DataFrame object.
463
466
"""
464
- benchmark_metrics_data = self ._get_deployment_configs_benchmarks_data (
465
- self .config_name , self .instance_type
466
- )
467
- keys = list (benchmark_metrics_data .keys ())
468
- df = pd .DataFrame (benchmark_metrics_data ).sort_values (by = [keys [0 ], keys [1 ]])
469
- return df
467
+ df = pd .DataFrame (self ._get_deployment_configs_benchmarks_data ())
468
+ default_mask = df .apply (lambda row : any ("Default" in str (val ) for val in row ), axis = 1 )
469
+ sorted_df = pd .concat ([df [default_mask ], df [~ default_mask ]])
470
+ return sorted_df
470
471
471
- def display_benchmark_metrics (self ) -> None :
472
+ def display_benchmark_metrics (self , * args , ** kwargs ) -> None :
472
473
"""Display deployment configs benchmark metrics."""
473
- print (self .benchmark_metrics .to_markdown (index = False ))
474
+ print (self .benchmark_metrics .to_markdown (index = False ), * args , ** kwargs )
474
475
475
476
def list_deployment_configs (self ) -> List [Dict [str , Any ]]:
476
477
"""List deployment configs for ``This`` model.
477
478
478
479
Returns:
479
480
List[Dict[str, Any]]: A list of deployment configs.
480
481
"""
481
- return [
482
- deployment_config .to_json ()
483
- for deployment_config in self ._get_deployment_configs (
484
- self .config_name , self .instance_type
485
- )
486
- ]
482
+ return deployment_config_response_data (
483
+ self ._get_deployment_configs (self .config_name , self .instance_type )
484
+ )
487
485
488
486
def _create_sagemaker_model (
489
487
self ,
@@ -873,71 +871,46 @@ def register_deploy_wrapper(*args, **kwargs):
873
871
874
872
return model_package
875
873
876
- @lru_cache
877
- def _get_deployment_configs_benchmarks_data (
878
- self , config_name : str , instance_type : str
879
- ) -> Dict [str , Any ]:
874
+ @_deployment_config_lru_cache
875
+ def _get_deployment_configs_benchmarks_data (self ) -> Dict [str , Any ]:
880
876
"""Deployment configs benchmark metrics.
881
877
882
- Args:
883
- config_name (str): Name of selected deployment config.
884
- instance_type (str): The selected Instance type.
885
878
Returns:
886
879
Dict[str, List[str]]: Deployment config benchmark data.
887
880
"""
888
881
return get_metrics_from_deployment_configs (
889
- self ._get_deployment_configs (config_name , instance_type )
882
+ self ._get_deployment_configs (None , None ),
890
883
)
891
884
892
- @lru_cache
893
- def _retrieve_selected_deployment_config (
894
- self , config_name : str , instance_type : str
895
- ) -> Optional [DeploymentConfigMetadata ]:
896
- """Retrieve the deployment config to apply to `This` model.
897
-
898
- Args:
899
- config_name (str): The name of the deployment config to retrieve.
900
- instance_type (str): The instance type of the deployment config to retrieve.
901
- Returns:
902
- Optional[Dict[str, Any]]: The retrieved deployment config.
903
- """
904
- if config_name is None :
905
- return None
906
-
907
- for deployment_config in self ._get_deployment_configs (config_name , instance_type ):
908
- if deployment_config .deployment_config_name == config_name :
909
- return deployment_config
910
- return None
911
-
912
- @lru_cache
885
+ @_deployment_config_lru_cache
913
886
def _get_deployment_configs (
914
- self , selected_config_name : str , selected_instance_type : str
887
+ self , selected_config_name : Optional [ str ] , selected_instance_type : Optional [ str ]
915
888
) -> List [DeploymentConfigMetadata ]:
916
889
"""Retrieve deployment configs metadata.
917
890
918
891
Args:
919
- selected_config_name (str): The name of the selected deployment config.
920
- selected_instance_type (str): The selected instance type.
892
+ selected_config_name (Optional[ str] ): The name of the selected deployment config.
893
+ selected_instance_type (Optional[ str] ): The selected instance type.
921
894
"""
922
895
deployment_configs = []
923
- if self ._metadata_configs is None :
896
+ if not self ._metadata_configs :
924
897
return deployment_configs
925
898
926
899
err = None
927
900
for config_name , metadata_config in self ._metadata_configs .items ():
928
- if err is None or "is not authorized to perform: #:GetProducts" not in err :
929
- err , metadata_config .benchmark_metrics = (
930
- add_instance_rate_stats_to_benchmark_metrics (
931
- self .region , metadata_config .benchmark_metrics
932
- )
933
- )
934
-
935
901
resolved_config = metadata_config .resolved_config
936
902
if selected_config_name == config_name :
937
903
instance_type_to_use = selected_instance_type
938
904
else :
939
905
instance_type_to_use = resolved_config .get ("default_inference_instance_type" )
940
906
907
+ if metadata_config .benchmark_metrics :
908
+ err , metadata_config .benchmark_metrics = (
909
+ add_instance_rate_stats_to_benchmark_metrics (
910
+ self .region , metadata_config .benchmark_metrics
911
+ )
912
+ )
913
+
941
914
init_kwargs = get_init_kwargs (
942
915
model_id = self .model_id ,
943
916
instance_type = instance_type_to_use ,
@@ -957,9 +930,9 @@ def _get_deployment_configs(
957
930
)
958
931
deployment_configs .append (deployment_config_metadata )
959
932
960
- if err is not None and "is not authorized to perform: #:GetProducts" in err :
933
+ if err and err [ "Code" ] == "AccessDeniedException" :
961
934
error_message = "Instance rate metrics will be omitted. Reason: %s"
962
- JUMPSTART_LOGGER .warning (error_message , err )
935
+ JUMPSTART_LOGGER .warning (error_message , err [ "Message" ] )
963
936
964
937
return deployment_configs
965
938
0 commit comments