Skip to content

Update: SM Endpoint Routing Strategy Support. #4702

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 7 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/sagemaker/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ class EndpointType(Enum):
INFERENCE_COMPONENT_BASED = (
"InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint
)


class RoutingStrategy(Enum):
"""Strategy for routing https traffics."""

RANDOM = "RANDOM"
"""The endpoint routes each request to a randomly chosen instance.
"""
LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
"""The endpoint routes requests to the specific instances that have
more capacity to process them.
"""
1 change: 1 addition & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def deploy(
endpoint_type=kwargs.get("endpoint_type", None),
resources=kwargs.get("resources", None),
managed_instance_scaling=kwargs.get("managed_instance_scaling", None),
routing_config=kwargs.get("routing_config", None),
)

def register(
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def get_deploy_kwargs(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: Optional[EndpointType] = None,
routing_config: Optional[Dict[str, Any]] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -586,6 +587,7 @@ def get_deploy_kwargs(
accept_eula=accept_eula,
endpoint_logging=endpoint_logging,
resources=resources,
routing_config=routing_config,
)

deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import absolute_import

from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any
from botocore.exceptions import ClientError

from sagemaker import payloads
Expand Down Expand Up @@ -496,6 +496,7 @@ def deploy(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.

Expand Down Expand Up @@ -590,6 +591,8 @@ def deploy(
endpoint.
endpoint_type (EndpointType): The type of endpoint used to deploy models.
(Default: EndpointType.MODEL_BASED).
routing_config (Optional[Dict]): Settings the control how the endpoint routes
incoming traffic to the instances that the endpoint hosts.

Raises:
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
Expand Down Expand Up @@ -625,6 +628,7 @@ def deploy(
managed_instance_scaling=managed_instance_scaling,
endpoint_type=endpoint_type,
model_type=self.model_type,
routing_config=routing_config,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"endpoint_logging",
"resources",
"endpoint_type",
"routing_config",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1658,6 +1659,7 @@ def __init__(
endpoint_logging: Optional[bool] = None,
resources: Optional[ResourceRequirements] = None,
endpoint_type: Optional[EndpointType] = None,
routing_config: Optional[Dict[str, Any]] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand Down Expand Up @@ -1690,6 +1692,7 @@ def __init__(
self.endpoint_logging = endpoint_logging
self.resources = resources
self.endpoint_type = endpoint_type
self.routing_config = routing_config


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down
17 changes: 16 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
import re
import copy
from typing import List, Dict, Optional, Union
from typing import List, Dict, Optional, Union, Any

import sagemaker
from sagemaker import (
Expand Down Expand Up @@ -66,6 +66,7 @@
resolve_nested_dict_value_from_config,
format_tags,
Tags,
_resolve_routing_config,
)
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.predictor_async import AsyncPredictor
Expand Down Expand Up @@ -1309,6 +1310,7 @@ def deploy(
resources: Optional[ResourceRequirements] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
managed_instance_scaling: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -1406,6 +1408,15 @@ def deploy(
Endpoint. (Default: None).
endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy models.
(Default: EndpointType.MODEL_BASED).
routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes incoming
traffic to the instances that the endpoint hosts.
Currently, support dictionary key ``RoutingStrategy``.

.. code:: python

{
"RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
}
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If no role is specified or
Expand Down Expand Up @@ -1458,6 +1469,8 @@ def deploy(
if self.role is None:
raise ValueError("Role can not be null for deploying a model")

routing_config = _resolve_routing_config(routing_config)

if (
inference_recommendation_id is not None
or self.inference_recommender_job_results is not None
Expand Down Expand Up @@ -1543,6 +1556,7 @@ def deploy(
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
managed_instance_scaling=managed_instance_scaling_config,
routing_config=routing_config,
)

self.sagemaker_session.endpoint_from_production_variants(
Expand Down Expand Up @@ -1625,6 +1639,7 @@ def deploy(
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
routing_config=routing_config,
)
if endpoint_name:
self.endpoint_name = endpoint_name
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
in order for model builder to build the artifacts correctly (according
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
``TRITON``,``TGI``, and ``TEI``.
``TRITON``, ``TGI``, and ``TEI``.
model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
new models without task metadata in the Hub, adding unsupported task types will throw
Expand Down
31 changes: 31 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
_log_sagemaker_config_single_substitution,
_log_sagemaker_config_merge,
)
from sagemaker.enums import RoutingStrategy
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.entities import PipelineVariable
Expand Down Expand Up @@ -1655,3 +1656,33 @@ def deep_override_dict(
)
flattened_dict1.update(flattened_dict2)
return unflatten_dict(flattened_dict1) if flattened_dict1 else {}


def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Resolve Routing Config

Args:
routing_config (Optional[Dict[str, Any]]): The routing config.

Returns:
Optional[Dict[str, Any]]: The resolved routing config.

Raises:
ValueError: If the RoutingStrategy is invalid.
"""

if routing_config:
routing_strategy = routing_config.get("RoutingStrategy", None)
if routing_strategy:
if isinstance(routing_strategy, RoutingStrategy):
return {"RoutingStrategy": routing_strategy.name}
if isinstance(routing_strategy, str) and (
routing_strategy.upper() == RoutingStrategy.RANDOM.name
or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name
):
return {"RoutingStrategy": routing_strategy.upper()}
raise ValueError(
"RoutingStrategy must be either RoutingStrategy.RANDOM "
"or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
)
return None
5 changes: 5 additions & 0 deletions tests/unit/sagemaker/model/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
routing_config=None,
)

sagemaker_session.create_model.assert_called_with(
Expand Down Expand Up @@ -184,6 +185,7 @@ def test_deploy_accelerator_type(
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
routing_config=None,
)

sagemaker_session.endpoint_from_production_variants.assert_called_with(
Expand Down Expand Up @@ -506,6 +508,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
routing_config=None,
)

sagemaker_session.endpoint_from_production_variants.assert_called_with(
Expand Down Expand Up @@ -938,6 +941,7 @@ def test_deploy_customized_volume_size_and_timeout(
volume_size=volume_size_gb,
model_data_download_timeout=model_data_download_timeout_sec,
container_startup_health_check_timeout=startup_health_check_timeout_sec,
routing_config=None,
)

sagemaker_session.create_model.assert_called_with(
Expand Down Expand Up @@ -987,6 +991,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
routing_config=None,
)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=name_from_base(MODEL_NAME),
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mock import call, patch, Mock, MagicMock, PropertyMock

import sagemaker
from sagemaker.enums import RoutingStrategy
from sagemaker.experiments._run_context import _RunContext
from sagemaker.session_settings import SessionSettings
from sagemaker.utils import (
Expand All @@ -50,6 +51,7 @@
_is_bad_link,
custom_extractall_tarfile,
can_model_package_source_uri_autopopulate,
_resolve_routing_config,
)
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
Expand Down Expand Up @@ -1866,3 +1868,30 @@ def test_deep_override_skip_keys(self):
expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]}

self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result)


@pytest.mark.parametrize(
"routing_config, expected",
[
({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}),
({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}),
(
{"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS},
{"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
),
(
{"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
{"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
),
({"RoutingStrategy": None}, None),
(None, None),
],
)
def test_resolve_routing_config(routing_config, expected):
res = _resolve_routing_config(routing_config)

assert res == expected


def test_resolve_routing_config_ex():
pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"}))