From bc30012424022641915c01e7df98765515a57813 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 8 Jul 2021 16:49:20 +0200 Subject: [PATCH] Upgrading asteroid image. --- .../docker_images/asteroid/app/main.py | 25 +++++++++++++------ .../docker_images/asteroid/requirements.txt | 2 +- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/api-inference-community/docker_images/asteroid/app/main.py b/api-inference-community/docker_images/asteroid/app/main.py index 3202a1af76..9bd89115f5 100644 --- a/api-inference-community/docker_images/asteroid/app/main.py +++ b/api-inference-community/docker_images/asteroid/app/main.py @@ -1,3 +1,4 @@ +import functools import logging import os from typing import Dict, Type @@ -37,7 +38,10 @@ } -def get_pipeline(task: str, model_id: str) -> Pipeline: +@functools.lru_cache() +def get_pipeline() -> Pipeline: + task = os.environ["TASK"] + model_id = os.environ["MODEL_ID"] if task not in ALLOWED_TASKS: raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}") return ALLOWED_TASKS[task](model_id) @@ -71,13 +75,18 @@ async def startup_event(): handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) logger.handlers = [handler] - task = os.environ["TASK"] - model_id = os.environ["MODEL_ID"] - app.pipeline = get_pipeline(task, model_id) + # Link between `api-inference-community` and framework code. + app.get_pipeline = get_pipeline + try: + get_pipeline() + except Exception: + # We can fail so we can show exception later. + pass if __name__ == "__main__": - task = os.environ["TASK"] - model_id = os.environ["MODEL_ID"] - - get_pipeline(task, model_id) + try: + get_pipeline() + except Exception: + # We can fail so we can show exception later. + pass diff --git a/api-inference-community/docker_images/asteroid/requirements.txt b/api-inference-community/docker_images/asteroid/requirements.txt index 2ae6d8025e..fce9dd6bc5 100644 --- a/api-inference-community/docker_images/asteroid/requirements.txt +++ b/api-inference-community/docker_images/asteroid/requirements.txt @@ -1,3 +1,3 @@ starlette==0.14.2 -api-inference-community==0.0.7 +api-inference-community==0.0.9 asteroid==0.4.4