fix: mainline alt config parsing #4602

Merged · 7 commits · Apr 23, 2024
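Summary (as inferred from the diff below): this PR fixes alternative-config parsing in the JumpStart metadata types. benchmark_metrics entries now map each instance type to a list of JumpStartBenchmarkStat objects instead of a single stat; JumpStartConfigComponent and JumpStartModelSpecs call super().__init__ rather than suppressing it; get_top_config_from_ranking is typed as Optional and documents its error cases; and set_config validates the config name and reports the names the model supports. The unit-test constants are updated to match the new shapes.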
72 changes: 50 additions & 22 deletions src/sagemaker/jumpstart/types.py
@@ -744,12 +744,12 @@ def _get_regional_property(


class JumpStartBenchmarkStat(JumpStartDataHolderType):
"""Data class JumpStart benchmark stats."""
"""Data class JumpStart benchmark stat."""

__slots__ = ["name", "value", "unit"]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartBenchmarkStat object
"""Initializes a JumpStartBenchmarkStat object.

Args:
spec (Dict[str, Any]): Dictionary representation of benchmark stat.
@@ -858,7 +858,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"model_subscription_link",
]

def __init__(self, fields: Optional[Dict[str, Any]]):
def __init__(self, fields: Dict[str, Any]):
"""Initializes a JumpStartMetadataFields object.

Args:
@@ -877,7 +877,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.version: str = json_obj.get("version")
self.min_sdk_version: str = json_obj.get("min_sdk_version")
self.incremental_training_supported: bool = bool(
json_obj.get("incremental_training_supported")
json_obj.get("incremental_training_supported", False)
)
self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
JumpStartECRSpecs(json_obj["hosting_ecr_specs"])
@@ -1038,7 +1038,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):

__slots__ = slots + JumpStartMetadataBaseFields.__slots__

def __init__( # pylint: disable=super-init-not-called
def __init__(
self,
component_name: str,
component: Optional[Dict[str, Any]],
@@ -1049,7 +1049,10 @@ def __init__(
component_name (str): Name of the component.
component (Dict[str, Any]):
Dictionary representation of the config component.
Raises:
ValueError: If the component field is invalid.
"""
super().__init__(component)
self.component_name = component_name
self.from_json(component)

@@ -1080,7 +1083,7 @@ def __init__(
self,
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, JumpStartBenchmarkStat],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
):
"""Initializes a JumpStartMetadataConfig object from its json representation.

@@ -1089,12 +1092,12 @@
The default base fields that are used to construct the final resolved config.
config_components (Dict[str, JumpStartConfigComponent]):
The list of components that are used to construct the resolved config.
benchmark_metrics (Dict[str, JumpStartBenchmarkStat]):
benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]):
The dictionary of benchmark metrics, keyed by instance type.
"""
self.base_fields = base_fields
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
self.resolved_metadata_config: Optional[Dict[str, Any]] = None

def to_json(self) -> Dict[str, Any]:
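For context, a minimal sketch of the new benchmark_metrics shape this PR introduces: each instance type now maps to a list of stats rather than a single stat. The raw dict below mirrors the test constants further down in this diff, and the comprehension mirrors the one added to from_json in this file:

from typing import Dict, List
from sagemaker.jumpstart.types import JumpStartBenchmarkStat

# Raw metadata now carries a list of stats per instance type.
raw_metrics = {"ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]}
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = {
    instance_type: [JumpStartBenchmarkStat(stat) for stat in stats]
    for instance_type, stats in raw_metrics.items()
}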
@@ -1104,7 +1107,7 @@ def to_json(self) -> Dict[str, Any]:

@property
def resolved_config(self) -> Dict[str, Any]:
"""Returns the final config that is resolved from the list of components.
"""Returns the final config that is resolved from the components map.

Constructs the final config by applying the config components in order
on top of the base default fields in the current model specs.
@@ -1139,7 +1142,7 @@ def __init__(

Args:
configs (Dict[str, JumpStartMetadataConfig]):
List of configs that the current model has.
The map of JumpStartMetadataConfig objects, keyed by config name.
config_rankings (JumpStartConfigRanking):
Config ranking class represents the ranking of the configs in the model.
scope (JumpStartScriptScope):
@@ -1158,19 +1161,30 @@ def get_top_config_from_ranking(
self,
ranking_name: str = JumpStartConfigRankingName.DEFAULT,
instance_type: Optional[str] = None,
) -> JumpStartMetadataConfig:
"""Gets the best the config based on config ranking."""
) -> Optional[JumpStartMetadataConfig]:
"""Gets the best the config based on config ranking.

Args:
ranking_name (str):
The ranking name that config priority is based on.
instance_type (Optional[str]):
The instance type on which the config selection is based.

Raises:
ValueError: If the config exists but is missing a config ranking.
NotImplementedError: If the scope is unrecognized.
"""
if self.configs and (
not self.config_rankings or not self.config_rankings.get(ranking_name)
):
raise ValueError("Config exists but missing config ranking.")
raise ValueError(f"Config exists but missing config ranking {ranking_name}.")

if self.scope == JumpStartScriptScope.INFERENCE:
instance_type_attribute = "supported_inference_instance_types"
elif self.scope == JumpStartScriptScope.TRAINING:
instance_type_attribute = "supported_training_instance_types"
else:
raise ValueError(f"Unknown script scope {self.scope}")
raise NotImplementedError(f"Unknown script scope {self.scope}")

rankings = self.config_rankings.get(ranking_name)
for config_name in rankings.rankings:
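A hedged usage sketch of the updated ranking lookup; the configs object below is assumed to be an already-populated JumpStartMetadataConfigs instance for inference. Since the return type is now Optional, callers should handle None:

# Sketch: `configs` is assumed to be a populated JumpStartMetadataConfigs.
top_config = configs.get_top_config_from_ranking(instance_type="ml.inf2.2xlarge")
if top_config is not None:
    resolved = top_config.resolved_config  # base fields overlaid with ranked components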
@@ -1198,12 +1212,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields):

__slots__ = JumpStartMetadataBaseFields.__slots__ + slots

def __init__(self, spec: Dict[str, Any]): # pylint: disable=super-init-not-called
def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartModelSpecs object from its json representation.

Args:
spec (Dict[str, Any]): Dictionary representation of spec.
"""
super().__init__(spec)
self.from_json(spec)
if self.inference_configs and self.inference_configs.get_top_config_from_ranking():
super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config)
@@ -1245,8 +1260,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
),
(
{
stat_name: JumpStartBenchmarkStat(stat)
for stat_name, stat in config.get("benchmark_metrics").items()
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
@@ -1297,8 +1312,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
),
(
{
stat_name: JumpStartBenchmarkStat(stat)
for stat_name, stat in config.get("benchmark_metrics").items()
stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
for stat_name, stats in config.get("benchmark_metrics").items()
}
if config and config.get("benchmark_metrics")
else None
@@ -1330,13 +1345,26 @@ def set_config(
config_name (str): Name of the config.
scope (JumpStartScriptScope, optional):
Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.

Raises:
ValueError: If the scope is not supported, or cannot find config name.
"""
if scope == JumpStartScriptScope.INFERENCE:
super().from_json(self.inference_configs.configs[config_name].resolved_config)
metadata_configs = self.inference_configs
elif scope == JumpStartScriptScope.TRAINING and self.training_supported:
super().from_json(self.training_configs.configs[config_name].resolved_config)
metadata_configs = self.training_configs
else:
raise ValueError(f"Unknown Jumpstart Script scope {scope}.")
raise ValueError(f"Unknown Jumpstart script scope {scope}.")

config_object = metadata_configs.configs.get(config_name)
if not config_object:
error_msg = f"Cannot find Jumpstart config name {config_name}. "
config_names = list(metadata_configs.configs.keys())
if config_names:
error_msg += f"List of config names that is supported by the model: {config_names}"
raise ValueError(error_msg)

super().from_json(config_object.resolved_config)

def supports_prepacked_inference(self) -> bool:
"""Returns True if the model has a prepacked inference artifact."""
32 changes: 21 additions & 11 deletions tests/unit/sagemaker/jumpstart/constants.py
@@ -6270,6 +6270,10 @@
"framework_version": "1.5.0",
"py_version": "py3",
},
"default_inference_instance_type": "ml.p2.xlarge",
"supported_inference_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"],
"default_training_instance_type": "ml.p2.xlarge",
"supported_training_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"],
"hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz",
"hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
"inference_vulnerable": False,
@@ -7658,25 +7662,25 @@
"inference_configs": {
"neuron-inference": {
"benchmark_metrics": {
"ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
"ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
},
"component_names": ["neuron-base"],
"component_names": ["neuron-inference"],
},
"neuron-inference-budget": {
"benchmark_metrics": {
"ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
"ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
},
"component_names": ["neuron-base"],
},
"gpu-inference-budget": {
"benchmark_metrics": {
"ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
"ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
},
"component_names": ["gpu-inference-budget"],
},
"gpu-inference": {
"benchmark_metrics": {
"ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
"ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
},
"component_names": ["gpu-inference"],
},
@@ -7686,7 +7690,13 @@
"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"]
},
"neuron-inference": {
"default_inference_instance_type": "ml.inf2.xlarge",
"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"],
"hosting_ecr_specs": {
"framework": "huggingface-llm-neuronx",
"framework_version": "0.0.17",
"py_version": "py310",
},
"hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/",
"hosting_instance_type_variants": {
"regional_aliases": {
@@ -7738,27 +7748,27 @@
"training_configs": {
"neuron-training": {
"benchmark_metrics": {
"ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"},
"ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"},
"ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}],
"ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}],
},
"component_names": ["neuron-training"],
},
"neuron-training-budget": {
"benchmark_metrics": {
"ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"},
"ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"},
"ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}],
"ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}],
},
"component_names": ["neuron-training-budget"],
},
"gpu-training": {
"benchmark_metrics": {
"ml.p3.2xlarge": {"name": "Latency", "value": "200", "unit": "Tokens/S"},
"ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}],
},
"component_names": ["gpu-training"],
},
"gpu-training-budget": {
"benchmark_metrics": {
"ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
"ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]
},
"component_names": ["gpu-training-budget"],
},
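Finally, a minimal, hypothetical check of the new list-valued benchmark metrics; the test name and raw dict are illustrative rather than part of this PR, and attribute access assumes JumpStartBenchmarkStat exposes its name/value/unit slots after construction:

from sagemaker.jumpstart.types import JumpStartBenchmarkStat

def test_benchmark_metrics_parse_as_lists():
    raw = {"ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}]}
    parsed = {
        instance_type: [JumpStartBenchmarkStat(stat) for stat in stats]
        for instance_type, stats in raw.items()
    }
    assert isinstance(parsed["ml.p3.2xlarge"], list)
    assert parsed["ml.p3.2xlarge"][0].name == "Latency"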