Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add formatting and upgrade Sentence Transformers api version for better error messages #119

Merged
merged 2 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions api-inference-community/api_inference_community/routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import logging
import os
import time
import base64
from typing import Any, Dict

from api_inference_community.validation import ffmpeg_convert, normalize_payload
Expand Down Expand Up @@ -100,7 +100,13 @@ def emit(self, record):
headers["content-type"] = "application/json"
for waveform, label in zip(waveforms, labels):
data = ffmpeg_convert(waveform, sampling_rate)
items.append({"label": label, "blob": base64.b64encode(data).decode("utf-8"), "content-type": "audio/flac"})
items.append(
{
"label": label,
"blob": base64.b64encode(data).decode("utf-8"),
"content-type": "audio/flac",
}
)
return JSONResponse(items, headers=headers, status_code=status_code)

return JSONResponse(
Expand Down
4 changes: 2 additions & 2 deletions api-inference-community/docker_images/common/app/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
import functools
import logging
import os
from typing import Dict, Type

from api_inference_community.routes import pipeline_route, status_ok
from app.pipelines import Pipeline
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.applications import Starlette
from starlette.routing import Route


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, List
from typing import List, Tuple

import numpy as np
from app.pipelines import Pipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import json
import base64
import json
import os
from unittest import TestCase, skipIf

from api_inference_community.validation import ffmpeg_read
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import os
from typing import Dict, Type
Expand All @@ -9,11 +10,18 @@
SentenceSimilarityPipeline,
)
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.routing import Route


TASK = os.getenv("TASK")
MODEL_ID = os.getenv("MODEL_ID")


logger = logging.getLogger(__name__)


# Add the allowed tasks
# Supported tasks are:
# - text-generation
Expand All @@ -34,7 +42,10 @@
}


def get_pipeline(task: str, model_id: str) -> Pipeline:
@functools.cache
def get_pipeline() -> Pipeline:
task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]
if task not in ALLOWED_TASKS:
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
return ALLOWED_TASKS[task](model_id)
Expand All @@ -45,14 +56,21 @@ def get_pipeline(task: str, model_id: str) -> Pipeline:
Route("/{whatever:path}", pipeline_route, methods=["POST"]),
]

app = Starlette(routes=routes)
middleware = [Middleware(GZipMiddleware, minimum_size=1000)]
if os.environ.get("DEBUG", "") == "1":
from starlette.middleware.cors import CORSMiddleware

app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"]
middleware.append(
Middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
)
)

app = Starlette(routes=routes, middleware=middleware)


@app.on_event("startup")
async def startup_event():
Expand All @@ -61,13 +79,13 @@ async def startup_event():
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.handlers = [handler]

task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]
app.pipeline = get_pipeline(task, model_id)
# Link between `api-inference-community` and framework code.
app.get_pipeline = get_pipeline


if __name__ == "__main__":
task = os.environ["TASK"]
model_id = os.environ["MODEL_ID"]

get_pipeline(task, model_id)
try:
get_pipeline()
except Exception:
# We can fail so we can show exception later.
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
starlette==0.14.2
api-inference-community==0.0.6
api-inference-community==0.0.7
-e git+https://github.com/osanseviero/sentence-transformers@v2_dev#egg=sentence-transformers
protobuf==3.17.1
18 changes: 13 additions & 5 deletions api-inference-community/manage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python
import argparse
import ast
import os
import subprocess
import uuid
import ast


class cd:
Expand Down Expand Up @@ -38,19 +38,27 @@ def create_docker(name: str) -> str:


def show(args):
directory = os.path.join(os.path.dirname(os.path.dirname(__file__)), "docker_images")
directory = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "docker_images"
)
for framework in sorted(os.listdir(directory)):
print(f"{framework}")
local_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "docker_images", framework,
"app", "main.py"
os.path.dirname(os.path.dirname(__file__)),
"docker_images",
framework,
"app",
"main.py",
)
# Using ast to prevent import issues with missing dependencies.
# and slow loads.
with open(local_path, "r") as source:
tree = ast.parse(source.read())
for item in tree.body:
if isinstance(item, ast.AnnAssign) and item.target.id == "ALLOWED_TASKS":
if (
isinstance(item, ast.AnnAssign)
and item.target.id == "ALLOWED_TASKS"
):
for key in item.value.keys:
print(" " * 4, key.value)

Expand Down
9 changes: 5 additions & 4 deletions api-inference-community/tests/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import json
import numpy as np
from unittest import TestCase
import logging
from api_inference_community.routes import status_ok, pipeline_route
import os
from unittest import TestCase

import numpy as np
from api_inference_community.routes import pipeline_route, status_ok
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.testclient import TestClient
Expand Down