diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 58d2700276..d09177249e 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -289,6 +289,8 @@ def post( # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") time.sleep(1) + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore continue diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 1c674c4007..d7571de176 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -284,6 +284,8 @@ async def post( ) from error # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" time.sleep(1) if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore diff --git a/utils/generate_async_inference_client.py b/utils/generate_async_inference_client.py index e761f999ad..eaa0d7f04d 100644 --- a/utils/generate_async_inference_client.py +++ b/utils/generate_async_inference_client.py @@ -232,6 +232,8 @@ def _rename_to_AsyncInferenceClient(code: str) -> str: ) from error # ...or wait 1s and retry logger.info(f"Waiting for model to be loaded on the server: {error}") + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" time.sleep(1) if timeout is not None: timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore