Skip to content

Commit 7b211fe

Browse files
authored
fix: Add back serialization for automatic speech recognition (#4586)
* Add back serialization for automatic speech recognition * Separate out integ test * Fix formatting * Update model
1 parent 995e78b commit 7b211fe

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

src/sagemaker/base_serializers.py

+2
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ def serialize(self, data):
397397
raise ValueError(f"Could not open/read file: {data}. {e}")
398398
if isinstance(data, bytes):
399399
return data
400+
if isinstance(data, dict) and "data" in data:
401+
return self.serialize(data["data"])
400402

401403
raise ValueError(f"Object of type {type(data)} is not Data serializable.")
402404

src/sagemaker/serve/builder/schema_builder.py

+5
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,11 @@ def _get_serializer(self, obj):
164164
return StringSerializer()
165165
if _is_jsonable(obj):
166166
return JSONSerializerWrapper()
167+
if isinstance(obj, dict) and "content_type" in obj:
168+
try:
169+
return DataSerializer(content_type=obj["content_type"])
170+
except ValueError as e:
171+
logger.error(e)
167172

168173
raise ValueError(
169174
(

src/sagemaker/serve/builder/transformers_builder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _create_transformers_model(self) -> Type[Model]:
9494
)
9595
hf_config = image_uris.config_for_framework("huggingface").get("inference")
9696
config = hf_config["versions"]
97-
base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]
97+
base_hf_version = sorted(config.keys(), key=lambda v: Version(v), reverse=True)[0]
9898

9999
if hf_model_md is None:
100100
raise ValueError("Could not fetch HF metadata")
@@ -269,7 +269,7 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
269269
if len(hugging_face_version.split(".")) == 2:
270270
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
271271
versions_to_return.append(base_fw_version)
272-
return sorted(versions_to_return)[0]
272+
return sorted(versions_to_return, reverse=True)[0]
273273

274274
def _build_for_transformers(self):
275275
"""Method that triggers model build

tests/integ/sagemaker/serve/test_schema_builder.py

+27
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,33 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
202202
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
203203

204204

205+
@pytest.mark.skipif(
206+
PYTHON_VERSION_IS_NOT_310,
207+
reason="Testing Schema Builder Simplification feature - Remote Schema",
208+
)
209+
@pytest.mark.parametrize(
210+
"model_id, task_provided, instance_type_provided",
211+
[("openai/whisper-tiny.en", "automatic-speech-recognition", "ml.m5.4xlarge")],
212+
)
213+
def test_model_builder_with_task_provided_remote_schema_mode_asr(
214+
model_id, task_provided, sagemaker_session, instance_type_provided
215+
):
216+
model_builder = ModelBuilder(
217+
model=model_id,
218+
model_metadata={"HF_TASK": task_provided},
219+
instance_type=instance_type_provided,
220+
)
221+
model = model_builder.build(sagemaker_session=sagemaker_session)
222+
223+
assert model is not None
224+
assert model_builder.schema_builder is not None
225+
226+
remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
227+
inputs, outputs = remote_hf_schema_helper.get_resolved_hf_schema_for_task(task_provided)
228+
assert model_builder.schema_builder.sample_input == inputs
229+
assert model_builder.schema_builder.sample_output == outputs
230+
231+
205232
def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
206233
model_builder = ModelBuilder(
207234
model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}

0 commit comments

Comments
 (0)