Skip to content

Commit

Permalink
Simplify manage to autodetect task+framework if possible. (#122)
Browse files Browse the repository at this point in the history
* Simplify manage to autodetect task+framework if possible.

* Style.

* Adding huggingface_hub as test dependency.

* Actions should run on PRs not on push.

* Fix failing test.

* Alphabetical order.

* Using library_name
  • Loading branch information
Narsil authored Jun 21, 2021
1 parent c6402fd commit af528fd
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-api-allennlp.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: allennlp-docker

on:
push:
pull-request:
paths:
- "api-inference-community/docker_images/allennlp/**"
jobs:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-api-quality.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Inference API code quality

on:
push:
pull_request:
paths:
- "api-inference-community/**"

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-api-tests.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Inference API Python-tests
on:
push:
pull_request:
paths:
- "api-inference-community/**"
jobs:
Expand All @@ -25,7 +25,7 @@ jobs:
working-directory: api-inference-community
run: |
pip install --upgrade pip
pip install pytest pillow httpx
pip install pytest pillow httpx huggingface_hub
pip install -e .
- run: make test
working-directory: api-inference-community
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class StringInput(BaseModel):
"token-classification": StringInput,
"translation": StringInput,
"zero-shot-classification": StringInput,
"text-to-speech": StringInput,
}

BATCH_ENABLED_PIPELINES = ["feature-extraction"]
Expand Down
27 changes: 23 additions & 4 deletions api-inference-community/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import subprocess
import uuid

from huggingface_hub import HfApi


class cd:
"""Context manager for changing the current working directory"""
Expand Down Expand Up @@ -63,6 +65,13 @@ def show(args):
print(" " * 4, key.value)


def resolve(model_id: str) -> [str, str]:
info = HfApi().model_info(model_id)
task = info.pipeline_tag
framework = info.library_name
return task, framework.replace("-", "_")


def start(args):
import sys

Expand All @@ -71,6 +80,14 @@ def start(args):
model_id = args.model_id
task = args.task
framework = args.framework
if task is None or framework is None:
rtask, rframework = resolve(model_id)
if task is None:
task = rtask
print(f"Inferred task : {task}")
if framework is None:
framework = rframework
print(f"Inferred framework : {framework}")

local_path = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "docker_images", framework
Expand All @@ -85,6 +102,12 @@ def docker(args):
model_id = args.model_id
task = args.task
framework = args.framework
if task is None or framework is None:
rtask, rframework = resolve(model_id)
if task is None:
task = rtask
if framework is None:
framework = rframework

tag = create_docker(framework)
run_docker_command = [
Expand Down Expand Up @@ -124,13 +147,11 @@ def main():
parser_start.add_argument(
"--task",
type=str,
required=True,
help="Which task to load",
)
parser_start.add_argument(
"--framework",
type=str,
required=True,
help="Which framework to load",
)
parser_start.set_defaults(func=start)
Expand All @@ -146,13 +167,11 @@ def main():
parser_docker.add_argument(
"--task",
type=str,
required=True,
help="Which task to load",
)
parser_docker.add_argument(
"--framework",
type=str,
required=True,
help="Which framework to load",
)
parser_docker.set_defaults(func=docker)
Expand Down
2 changes: 1 addition & 1 deletion api-inference-community/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def startup_event():
app.get_pipeline = get_pipeline

with TestClient(app) as client:
response = client.post("/", data=b"")
response = client.post("/", data=b"2222")

self.assertEqual(
response.status_code,
Expand Down

0 comments on commit af528fd

Please # to comment.