Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sharding Support for Neo Optimization Jobs. #4931

Closed
wants to merge 12 commits into from
14 changes: 14 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def __init__(
self.endpoint_name = None
self.inference_component_name = None
self._is_compiled_model = False
self._is_sharded_model = False
self._compilation_job_name = None
self._is_edge_packaged_model = False
self.inference_recommender_job_results = None
Expand Down Expand Up @@ -1599,6 +1600,19 @@ def deploy(
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._is_sharded_model and self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if endpoint_name:
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -727,7 +730,7 @@ def _optimize_for_jumpstart(
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

optimization_config, override_env = _extract_optimization_config_and_env(
quantization_config, compilation_config
quantization_config, compilation_config, sharding_config
)
if not optimization_config and is_compilation:
override_env = override_env or pysdk_model_env_vars
Expand Down Expand Up @@ -792,7 +795,15 @@ def _optimize_for_jumpstart(
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
if optimization_env_vars:
self.pysdk_model.env.update(optimization_env_vars)
if quantization_config or is_compilation:

if sharding_config and self.pysdk_model._enable_network_isolation:
logger.warning(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access. Setting it to False."
)
self.pysdk_model._enable_network_isolation = False

if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
return None

Expand Down
37 changes: 36 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,7 @@ def optimize(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1142,6 +1143,8 @@ def optimize(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand Down Expand Up @@ -1170,6 +1173,7 @@ def optimize(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
Expand All @@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -1238,6 +1245,26 @@ def _model_builder_optimize_wrapper(
if quantization_config and compilation_config:
raise ValueError("Quantization config and compilation config are mutually exclusive.")

if sharding_config and (
quantization_config or compilation_config or speculative_decoding_config
):
raise ValueError(
"Sharding config is mutually exclusive and cannot be combined with any "
"other optimization."
)

if sharding_config and (
(env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars)
or (
sharding_config.get("OverrideEnvironment")
and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"]
)
):
raise ValueError(
"OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with "
"sharding config."
)

self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
self.instance_type = instance_type or self.instance_type
self.role_arn = role_arn or self.role_arn
Expand All @@ -1254,6 +1281,7 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
Expand All @@ -1272,12 +1300,16 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
max_runtime_in_sec=max_runtime_in_sec,
)

if sharding_config:
self.pysdk_model._is_sharded_model = True

if input_args:
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
Expand All @@ -1297,6 +1329,7 @@ def _optimize_for_hf(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1312,6 +1345,8 @@ def _optimize_for_hf(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -1327,7 +1362,7 @@ def _optimize_for_hf(
self.pysdk_model, speculative_decoding_config, False
)

if quantization_config or compilation_config:
if quantization_config or compilation_config or sharding_config:
create_optimization_job_args = {
"OptimizationJobName": job_name,
"DeploymentInstanceType": self.instance_type,
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/serve/utils/optimize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,16 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


def _extract_optimization_config_and_env(
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.

Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.

Returns:
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
Expand All @@ -279,6 +282,8 @@ def _extract_optimization_config_and_env(
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
"OverrideEnvironment"
)
if sharding_config:
return {"ModelShardingConfig": sharding_config}, sharding_config.get("OverrideEnvironment")
return None, None


Expand Down
50 changes: 50 additions & 0 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,56 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
sagemaker_session.create_model.reset_mock()


@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
    tar_and_upload_dir, repack_model, sagemaker_session
):
    """Deploying a sharded model must force the inference-component-based path.

    Even though ``EndpointType.MODEL_BASED`` is explicitly requested, a model
    flagged with ``_is_sharded_model`` must be deployed via the
    inference-component-based endpoint path (endpoint existence check, model
    creation, and inference component creation on the session).
    """
    # NOTE: stacked @patch decorators inject mocks bottom-up, so the innermost
    # patch (tar_and_upload_dir) is the first positional argument. The original
    # signature had the two mock names swapped and misspelled ("tar_and_uload_dir");
    # both mocks are unused, so this is a naming-only correction.
    framework_model_classes_to_kwargs = {
        HuggingFaceModel: {
            "pytorch_version": "1.7.1",
            "py_version": "py36",
            "transformers_version": "4.6.1",
        },
    }

    sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)

    source_dir = "s3://blah/blah/blah"
    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
        test_sharded_model = framework_model_class(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            model_data=source_dir,
            **kwargs,
        )
        test_sharded_model._is_sharded_model = True
        test_sharded_model.deploy(
            instance_type="ml.m2.xlarge",
            initial_instance_count=INSTANCE_COUNT,
            # MODEL_BASED is requested on purpose; deploy() must override it
            # for a sharded model.
            endpoint_type=EndpointType.MODEL_BASED,
            resources=ResourceRequirements(
                requests={
                    "num_accelerators": 1,
                    "memory": 8192,
                    "copies": 1,
                },
                limits={},
            ),
        )

        # Verify the inference-component-based endpoint and inference
        # component creation path was taken despite the MODEL_BASED request.
        sagemaker_session.endpoint_in_service_or_not.assert_called_once()
        sagemaker_session.create_model.assert_called_once()
        sagemaker_session.create_inference_component.assert_called_once()

        sagemaker_session.create_inference_component.reset_mock()
        sagemaker_session.endpoint_in_service_or_not.reset_mock()
        sagemaker_session.create_model.reset_mock()


@patch("sagemaker.utils.repack_model")
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):

Expand Down
51 changes: 51 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,57 @@ def test_optimize_quantize_for_jumpstart(

self.assertIsNotNone(out_put)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_sharding_for_jumpstart(
self,
mock_serve_settings,
mock_telemetry,
):
mock_sagemaker_session = Mock()

mock_pysdk_model = Mock()
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
mock_pysdk_model.model_data = mock_model_data
mock_pysdk_model.image_uri = mock_tgi_image_uri
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]

sample_input = {
"inputs": "The diamondback terrapin or simply terrapin is a species "
"of turtle native to the brackish coastal tidal marshes of the",
"parameters": {"max_new_tokens": 1024},
}
sample_output = [
{
"generated_text": "The diamondback terrapin or simply terrapin is a "
"species of turtle native to the brackish coastal "
"tidal marshes of the east coast."
}
]

model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-70b",
schema_builder=SchemaBuilder(sample_input, sample_output),
sagemaker_session=mock_sagemaker_session,
)

model_builder.pysdk_model = mock_pysdk_model

out_put = model_builder._optimize_for_jumpstart(
accept_eula=True,
sharding_config={
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
},
env_vars={
"OPTION_TENSOR_PARALLEL_DEGREE": "1",
"OPTION_MAX_ROLLING_BATCH_SIZE": "2",
},
output_path="s3://bucket/code/",
)

self.assertIsNotNone(out_put)

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@patch(
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,6 +2667,39 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
),
)

@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
mock_sagemaker_session = Mock()
model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-70b",
sagemaker_session=mock_sagemaker_session,
)

self.assertRaisesRegex(
ValueError,
"Sharding config is mutually exclusive and cannot be combined with any other optimization.",
lambda: model_builder.optimize(
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
),
)

@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
mock_sagemaker_session = Mock()
model_builder = ModelBuilder(
model="meta-textgeneration-llama-3-70b",
sagemaker_session=mock_sagemaker_session,
)

self.assertRaisesRegex(
ValueError,
"OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
lambda: model_builder.optimize(
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
),
)

@patch.object(ModelBuilder, "_prepare_for_mode")
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
def test_optimize_for_hf_with_custom_s3_path(
Expand Down
Loading
Loading