diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 340d35b250..d3ff3cc934 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -372,6 +372,7 @@ def __init__(
         self.endpoint_name = None
         self.inference_component_name = None
         self._is_compiled_model = False
+        self._is_sharded_model = False
         self._compilation_job_name = None
         self._is_edge_packaged_model = False
         self.inference_recommender_job_results = None
@@ -1599,6 +1600,11 @@ def deploy(
             if self._base_name is not None:
                 self._base_name = "-".join((self._base_name, compiled_model_suffix))
 
+        if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
+            logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
+                            "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
+            endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
+
         # Support multiple models on same endpoint
         if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
             if endpoint_name:
diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py
index cfb43b813a..c058337470 100644
--- a/src/sagemaker/serve/builder/jumpstart_builder.py
+++ b/src/sagemaker/serve/builder/jumpstart_builder.py
@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,
@@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -727,7 +730,7 @@ def _optimize_for_jumpstart(
             pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
 
         optimization_config, override_env = _extract_optimization_config_and_env(
-            quantization_config, compilation_config
+            quantization_config, compilation_config, sharding_config
         )
         if not optimization_config and is_compilation:
             override_env = override_env or pysdk_model_env_vars
@@ -792,7 +795,7 @@ def _optimize_for_jumpstart(
         optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
         if optimization_env_vars:
             self.pysdk_model.env.update(optimization_env_vars)
-        if quantization_config or is_compilation:
+        if quantization_config or sharding_config or is_compilation:
             return create_optimization_job_args
         return None
 
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index d1f1ab6ba2..71469eb44c 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -1119,6 +1119,7 @@ def optimize(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,
@@ -1142,6 +1143,8 @@ def optimize(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
             quantization_config=quantization_config,
             compilation_config=compilation_config,
             speculative_decoding_config=speculative_decoding_config,
+            sharding_config=sharding_config,
             env_vars=env_vars,
             vpc_config=vpc_config,
             kms_key=kms_key,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
         if quantization_config and compilation_config:
             raise ValueError("Quantization config and compilation config are mutually exclusive.")
 
+        if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
+            raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
+
+        if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):
+            raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")
+
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
         self.instance_type = instance_type or self.instance_type
         self.role_arn = role_arn or self.role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
                 quantization_config=quantization_config,
                 compilation_config=compilation_config,
                 speculative_decoding_config=speculative_decoding_config,
+                sharding_config=sharding_config,
                 env_vars=env_vars,
                 vpc_config=vpc_config,
                 kms_key=kms_key,
@@ -1272,12 +1286,16 @@ def _model_builder_optimize_wrapper(
                 quantization_config=quantization_config,
                 compilation_config=compilation_config,
                 speculative_decoding_config=speculative_decoding_config,
+                sharding_config=sharding_config,
                 env_vars=env_vars,
                 vpc_config=vpc_config,
                 kms_key=kms_key,
                 max_runtime_in_sec=max_runtime_in_sec,
             )
 
+        if sharding_config:
+            self.pysdk_model._is_sharded_model = True
+
         if input_args:
             self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
             job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
         quantization_config: Optional[Dict] = None,
         compilation_config: Optional[Dict] = None,
         speculative_decoding_config: Optional[Dict] = None,
+        sharding_config: Optional[Dict] = None,
         env_vars: Optional[Dict] = None,
         vpc_config: Optional[Dict] = None,
         kms_key: Optional[str] = None,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
             compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
             speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
                 Defaults to ``None``
+            sharding_config (Optional[Dict]): Model sharding configuration.
+                Defaults to ``None``
             env_vars (Optional[Dict]): Additional environment variables to run the optimization
                 container. Defaults to ``None``.
             vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
             self.pysdk_model, speculative_decoding_config, False
         )
 
-        if quantization_config or compilation_config:
+        if quantization_config or compilation_config or sharding_config:
             create_optimization_job_args = {
                 "OptimizationJobName": job_name,
                 "DeploymentInstanceType": self.instance_type,
diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py
index 5781c0bade..cacf647bcd 100644
--- a/src/sagemaker/serve/utils/optimize_utils.py
+++ b/src/sagemaker/serve/utils/optimize_utils.py
@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
 
 
 def _extract_optimization_config_and_env(
-    quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
+    quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
+    sharding_config: Optional[Dict] = None
 ) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
     """Extracts optimization config and environment variables.
 
     Args:
         quantization_config (Optional[Dict]): The quantization config.
         compilation_config (Optional[Dict]): The compilation config.
+        sharding_config (Optional[Dict]): The sharding config.
 
     Returns:
         Optional[Tuple[Optional[Dict], Optional[Dict]]]:
@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
         return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
             "OverrideEnvironment"
         )
+    if sharding_config:
+        return {"ModelShardingConfig": sharding_config}, sharding_config.get(
+            "OverrideEnvironment"
+        )
     return None, None
 
 
diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py
index e43ad0ed0a..92bb6e05b9 100644
--- a/tests/unit/sagemaker/model/test_model.py
+++ b/tests/unit/sagemaker/model/test_model.py
@@ -958,6 +958,54 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
         sagemaker_session.endpoint_in_service_or_not.reset_mock()
         sagemaker_session.create_model.reset_mock()
 
+@patch("sagemaker.utils.repack_model")
+@patch("sagemaker.fw_utils.tar_and_upload_dir")
+def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
+    repack_model, tar_and_uload_dir, sagemaker_session
+):
+    framework_model_classes_to_kwargs = {
+        HuggingFaceModel: {
+            "pytorch_version": "1.7.1",
+            "py_version": "py36",
+            "transformers_version": "4.6.1"
+        },
+    }
+
+    sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)
+
+    source_dir = "s3://blah/blah/blah"
+    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
+        test_sharded_model = framework_model_class(
+            entry_point=ENTRY_POINT_INFERENCE,
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            model_data=source_dir,
+            **kwargs,
+        )
+        test_sharded_model._is_sharded_model = True
+        test_sharded_model.deploy(
+            instance_type="ml.m2.xlarge",
+            initial_instance_count=INSTANCE_COUNT,
+            endpoint_type=EndpointType.MODEL_BASED,
+            resources=ResourceRequirements(
+                requests={
+                    "num_accelerators": 1,
+                    "memory": 8192,
+                    "copies": 1,
+                },
+                limits={},
+            ),
+        )
+
+        # Verified inference component based endpoint and inference component creation
+        # path
+        sagemaker_session.endpoint_in_service_or_not.assert_called_once()
+        sagemaker_session.create_model.assert_called_once()
+        sagemaker_session.create_inference_component.assert_called_once()
+
+        sagemaker_session.create_inference_component.reset_mock()
+        sagemaker_session.endpoint_in_service_or_not.reset_mock()
+        sagemaker_session.create_model.reset_mock()
 
 @patch("sagemaker.utils.repack_model")
 def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py
index 248955c273..16c8c5c390 100644
--- a/tests/unit/sagemaker/serve/builder/test_js_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py
@@ -1198,6 +1198,57 @@ def test_optimize_quantize_for_jumpstart(
 
         self.assertIsNotNone(out_put)
 
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_sharding_for_jumpstart(
+        self,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
+        mock_pysdk_model.model_data = mock_model_data
+        mock_pysdk_model.image_uri = mock_tgi_image_uri
+        mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
+        mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
+
+        sample_input = {
+            "inputs": "The diamondback terrapin or simply terrapin is a species "
+            "of turtle native to the brackish coastal tidal marshes of the",
+            "parameters": {"max_new_tokens": 1024},
+        }
+        sample_output = [
+            {
+                "generated_text": "The diamondback terrapin or simply terrapin is a "
+                "species of turtle native to the brackish coastal "
+                "tidal marshes of the east coast."
+            }
+        ]
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            schema_builder=SchemaBuilder(sample_input, sample_output),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        out_put = model_builder._optimize_for_jumpstart(
+            accept_eula=True,
+            sharding_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            env_vars={
+                "OPTION_TENSOR_PARALLEL_DEGREE": "1",
+                "OPTION_MAX_ROLLING_BATCH_SIZE": "2",
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertIsNotNone(out_put)
+
     @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
     @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
     @patch(
diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py
index b50aa17c34..6e0f1bfe5e 100644
--- a/tests/unit/sagemaker/serve/builder/test_model_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -2667,6 +2667,39 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
             ),
         )
 
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Sharding config is mutually exclusive and cannot be combined with any other optimization.",
+            lambda: model_builder.optimize(
+                compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
+            lambda: model_builder.optimize(
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
     @patch.object(ModelBuilder, "_prepare_for_mode")
     @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
     def test_optimize_for_hf_with_custom_s3_path(
diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
index a8dc6d74f4..38e4d0c6fe 100644
--- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
+++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):
 
 
 @pytest.mark.parametrize(
-    "quantization_config, compilation_config, expected_config, expected_env",
+    "quantization_config, compilation_config, sharding_config, expected_config, expected_env",
     [
         (
             None,
@@ -270,6 +270,7 @@ def test_is_s3_uri(s3_uri, expected):
                     "OPTION_TENSOR_PARALLEL_DEGREE": "2",
                 }
             },
+            None,
             {
                 "ModelCompilationConfig": {
                     "OverrideEnvironment": {
@@ -288,6 +289,7 @@ def test_is_s3_uri(s3_uri, expected):
                 }
             },
             None,
+            None,
             {
                 "ModelQuantizationConfig": {
                     "OverrideEnvironment": {
@@ -299,13 +301,32 @@ def test_is_s3_uri(s3_uri, expected):
                 "OPTION_TENSOR_PARALLEL_DEGREE": "2",
             },
         ),
-        (None, None, None, None),
+        (
+            None,
+            None,
+            {
+                "OverrideEnvironment": {
+                    "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+                }
+            },
+            {
+                "ModelShardingConfig": {
+                    "OverrideEnvironment": {
+                        "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+                    }
+                },
+            },
+            {
+                "OPTION_TENSOR_PARALLEL_DEGREE": "2",
+            },
+        ),
+        (None, None, None, None, None),
     ],
 )
 def test_extract_optimization_config_and_env(
-    quantization_config, compilation_config, expected_config, expected_env
+    quantization_config, compilation_config, sharding_config, expected_config, expected_env
 ):
-    assert _extract_optimization_config_and_env(quantization_config, compilation_config) == (
+    assert _extract_optimization_config_and_env(quantization_config, compilation_config, sharding_config) == (
         expected_config,
         expected_env,
     )