@@ -744,12 +744,12 @@ def _get_regional_property(
744
744
745
745
746
746
class JumpStartBenchmarkStat (JumpStartDataHolderType ):
747
- """Data class JumpStart benchmark stats ."""
747
+ """Data class JumpStart benchmark stat ."""
748
748
749
749
__slots__ = ["name" , "value" , "unit" ]
750
750
751
751
def __init__ (self , spec : Dict [str , Any ]):
752
- """Initializes a JumpStartBenchmarkStat object
752
+ """Initializes a JumpStartBenchmarkStat object.
753
753
754
754
Args:
755
755
spec (Dict[str, Any]): Dictionary representation of benchmark stat.
@@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
858
858
"model_subscription_link" ,
859
859
]
860
860
861
- def __init__ (self , fields : Optional [ Dict [str , Any ] ]):
861
+ def __init__ (self , fields : Dict [str , Any ]):
862
862
"""Initializes a JumpStartMetadataFields object.
863
863
864
864
Args:
@@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
877
877
self .version : str = json_obj .get ("version" )
878
878
self .min_sdk_version : str = json_obj .get ("min_sdk_version" )
879
879
self .incremental_training_supported : bool = bool (
880
- json_obj .get ("incremental_training_supported" )
880
+ json_obj .get ("incremental_training_supported" , False )
881
881
)
882
882
self .hosting_ecr_specs : Optional [JumpStartECRSpecs ] = (
883
883
JumpStartECRSpecs (json_obj ["hosting_ecr_specs" ])
@@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
1038
1038
1039
1039
__slots__ = slots + JumpStartMetadataBaseFields .__slots__
1040
1040
1041
- def __init__ ( # pylint: disable=super-init-not-called
1041
+ def __init__ (
1042
1042
self ,
1043
1043
component_name : str ,
1044
1044
component : Optional [Dict [str , Any ]],
@@ -1049,7 +1049,10 @@ def __init__( # pylint: disable=super-init-not-called
1049
1049
component_name (str): Name of the component.
1050
1050
component (Dict[str, Any]):
1051
1051
Dictionary representation of the config component.
1052
+ Raises:
1053
+ ValueError: If the component field is invalid.
1052
1054
"""
1055
+ super ().__init__ (component )
1053
1056
self .component_name = component_name
1054
1057
self .from_json (component )
1055
1058
@@ -1080,7 +1083,7 @@ def __init__(
1080
1083
self ,
1081
1084
base_fields : Dict [str , Any ],
1082
1085
config_components : Dict [str , JumpStartConfigComponent ],
1083
- benchmark_metrics : Dict [str , JumpStartBenchmarkStat ],
1086
+ benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ],
1084
1087
):
1085
1088
"""Initializes a JumpStartMetadataConfig object from its json representation.
1086
1089
@@ -1089,12 +1092,12 @@ def __init__(
1089
1092
The default base fields that are used to construct the final resolved config.
1090
1093
config_components (Dict[str, JumpStartConfigComponent]):
1091
1094
The list of components that are used to construct the resolved config.
1092
- benchmark_metrics (Dict[str, JumpStartBenchmarkStat]):
1095
+ benchmark_metrics (Dict[str, List[ JumpStartBenchmarkStat] ]):
1093
1096
The dictionary of benchmark metrics with name being the key.
1094
1097
"""
1095
1098
self .base_fields = base_fields
1096
1099
self .config_components : Dict [str , JumpStartConfigComponent ] = config_components
1097
- self .benchmark_metrics : Dict [str , JumpStartBenchmarkStat ] = benchmark_metrics
1100
+ self .benchmark_metrics : Dict [str , List [ JumpStartBenchmarkStat ] ] = benchmark_metrics
1098
1101
self .resolved_metadata_config : Optional [Dict [str , Any ]] = None
1099
1102
1100
1103
def to_json (self ) -> Dict [str , Any ]:
@@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]:
1104
1107
1105
1108
@property
1106
1109
def resolved_config (self ) -> Dict [str , Any ]:
1107
- """Returns the final config that is resolved from the list of components.
1110
+ """Returns the final config that is resolved from the components map .
1108
1111
1109
1112
Construct the final config by applying the list of configs from list index,
1110
1113
and apply to the base default fields in the current model specs.
@@ -1139,7 +1142,7 @@ def __init__(
1139
1142
1140
1143
Args:
1141
1144
configs (Dict[str, JumpStartMetadataConfig]):
1142
- List of configs that the current model has .
1145
+ The map of JumpStartMetadataConfig object, with config name being the key .
1143
1146
config_rankings (JumpStartConfigRanking):
1144
1147
Config ranking class represents the ranking of the configs in the model.
1145
1148
scope (JumpStartScriptScope):
@@ -1158,19 +1161,30 @@ def get_top_config_from_ranking(
1158
1161
self ,
1159
1162
ranking_name : str = JumpStartConfigRankingName .DEFAULT ,
1160
1163
instance_type : Optional [str ] = None ,
1161
- ) -> JumpStartMetadataConfig :
1162
- """Gets the best the config based on config ranking."""
1164
+ ) -> Optional [JumpStartMetadataConfig ]:
1165
+ """Gets the best the config based on config ranking.
1166
+
1167
+ Args:
1168
+ ranking_name (str):
1169
+ The ranking name that config priority is based on.
1170
+ instance_type (Optional[str]):
1171
+ The instance type which the config selection is based on.
1172
+
1173
+ Raises:
1174
+ ValueError: If the config exists but missing config ranking.
1175
+ NotImplementedError: If the scope is unrecognized.
1176
+ """
1163
1177
if self .configs and (
1164
1178
not self .config_rankings or not self .config_rankings .get (ranking_name )
1165
1179
):
1166
- raise ValueError ("Config exists but missing config ranking." )
1180
+ raise ValueError (f "Config exists but missing config ranking { ranking_name } ." )
1167
1181
1168
1182
if self .scope == JumpStartScriptScope .INFERENCE :
1169
1183
instance_type_attribute = "supported_inference_instance_types"
1170
1184
elif self .scope == JumpStartScriptScope .TRAINING :
1171
1185
instance_type_attribute = "supported_training_instance_types"
1172
1186
else :
1173
- raise ValueError (f"Unknown script scope { self .scope } " )
1187
+ raise NotImplementedError (f"Unknown script scope { self .scope } " )
1174
1188
1175
1189
rankings = self .config_rankings .get (ranking_name )
1176
1190
for config_name in rankings .rankings :
@@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields):
1198
1212
1199
1213
__slots__ = JumpStartMetadataBaseFields .__slots__ + slots
1200
1214
1201
- def __init__ (self , spec : Dict [str , Any ]): # pylint: disable=super-init-not-called
1215
+ def __init__ (self , spec : Dict [str , Any ]):
1202
1216
"""Initializes a JumpStartModelSpecs object from its json representation.
1203
1217
1204
1218
Args:
1205
1219
spec (Dict[str, Any]): Dictionary representation of spec.
1206
1220
"""
1221
+ super ().__init__ (spec )
1207
1222
self .from_json (spec )
1208
1223
if self .inference_configs and self .inference_configs .get_top_config_from_ranking ():
1209
1224
super ().from_json (self .inference_configs .get_top_config_from_ranking ().resolved_config )
@@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1245
1260
),
1246
1261
(
1247
1262
{
1248
- stat_name : JumpStartBenchmarkStat (stat )
1249
- for stat_name , stat in config .get ("benchmark_metrics" ).items ()
1263
+ stat_name : [ JumpStartBenchmarkStat (stat ) for stat in stats ]
1264
+ for stat_name , stats in config .get ("benchmark_metrics" ).items ()
1250
1265
}
1251
1266
if config and config .get ("benchmark_metrics" )
1252
1267
else None
@@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1297
1312
),
1298
1313
(
1299
1314
{
1300
- stat_name : JumpStartBenchmarkStat (stat )
1301
- for stat_name , stat in config .get ("benchmark_metrics" ).items ()
1315
+ stat_name : [ JumpStartBenchmarkStat (stat ) for stat in stats ]
1316
+ for stat_name , stats in config .get ("benchmark_metrics" ).items ()
1302
1317
}
1303
1318
if config and config .get ("benchmark_metrics" )
1304
1319
else None
@@ -1330,13 +1345,26 @@ def set_config(
1330
1345
config_name (str): Name of the config.
1331
1346
scope (JumpStartScriptScope, optional):
1332
1347
Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.
1348
+
1349
+ Raises:
1350
+ ValueError: If the scope is not supported, or cannot find config name.
1333
1351
"""
1334
1352
if scope == JumpStartScriptScope .INFERENCE :
1335
- super (). from_json ( self .inference_configs . configs [ config_name ]. resolved_config )
1353
+ metadata_configs = self .inference_configs
1336
1354
elif scope == JumpStartScriptScope .TRAINING and self .training_supported :
1337
- super (). from_json ( self .training_configs . configs [ config_name ]. resolved_config )
1355
+ metadata_configs = self .training_configs
1338
1356
else :
1339
- raise ValueError (f"Unknown Jumpstart Script scope { scope } ." )
1357
+ raise ValueError (f"Unknown Jumpstart script scope { scope } ." )
1358
+
1359
+ config_object = metadata_configs .configs .get (config_name )
1360
+ if not config_object :
1361
+ error_msg = f"Cannot find Jumpstart config name { config_name } . "
1362
+ config_names = list (metadata_configs .configs .keys ())
1363
+ if config_names :
1364
+ error_msg += f"List of config names that is supported by the model: { config_names } "
1365
+ raise ValueError (error_msg )
1366
+
1367
+ super ().from_json (config_object .resolved_config )
1340
1368
1341
1369
def supports_prepacked_inference (self ) -> bool :
1342
1370
"""Returns True if the model has a prepacked inference artifact."""
0 commit comments